From 66191ea976e6477d1c2fca8bcf6c269a5dbfcdf4 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Fri, 5 Jun 2026 03:37:02 -0400 Subject: [PATCH 1/2] [Arith] Make Analyzer a tvm-ffi Object --- include/tvm/arith/analyzer.h | 165 ++++++++++------ include/tvm/arith/int_set.h | 9 +- include/tvm/arith/iter_affine_map.h | 8 +- include/tvm/ir/scope_stack.h | 2 +- include/tvm/ir/with_context.h | 4 +- include/tvm/relax/analysis.h | 90 +++++++-- include/tvm/relax/block_builder.h | 2 +- include/tvm/relax/dataflow_pattern.h | 3 +- .../tvm/relax/distributed/axis_group_graph.h | 6 +- include/tvm/relax/utils.h | 2 +- include/tvm/s_tir/analysis.h | 6 +- include/tvm/tirx/analysis.h | 4 - include/tvm/tirx/index_map.h | 88 +++++++-- include/tvm/topi/detail/constant_utils.h | 2 +- include/tvm/topi/nn.h | 4 +- include/tvm/topi/nn/bnn.h | 2 +- include/tvm/topi/nn/dilate.h | 2 +- include/tvm/topi/nn/pooling.h | 8 +- include/tvm/topi/transform.h | 21 +- python/tvm/arith/analyzer.py | 60 +++--- python/tvm/tirx/function.py | 26 ++- src/arith/analyzer.cc | 180 ++++++++---------- src/arith/bound_deducer.cc | 8 +- src/arith/canonical_simplify.cc | 10 +- src/arith/conjunctive_normal_form.cc | 26 +-- src/arith/conjunctive_normal_form.h | 3 +- src/arith/const_int_bound.cc | 6 +- src/arith/detect_linear_equation.cc | 8 +- src/arith/domain_touched.cc | 2 +- src/arith/int_constraints.cc | 36 ++-- src/arith/int_set.cc | 106 ++++++----- src/arith/interval_set.h | 4 +- src/arith/ir_mutator_with_analyzer.cc | 10 +- src/arith/ir_mutator_with_analyzer.h | 4 +- src/arith/ir_visitor_with_analyzer.cc | 24 +-- src/arith/ir_visitor_with_analyzer.h | 2 +- src/arith/iter_affine_map.cc | 54 +++--- src/arith/modular_set.cc | 6 +- src/arith/presburger_set.cc | 4 +- src/arith/rewrite_simplify.cc | 2 +- src/arith/rewrite_simplify.h | 2 +- src/arith/solve_linear_equation.cc | 22 +-- src/arith/solve_linear_inequality.cc | 62 +++--- src/relax/analysis/layout_transformation.cc | 8 +- src/relax/analysis/shape_analysis.cc | 4 +- src/relax/analysis/struct_info_analysis.cc | 99 ++++++---- src/relax/analysis/tir_op_pattern_kind.cc | 18 +- src/relax/distributed/axis_group_graph.cc | 18 +- .../lower_global_view_to_local_view.cc | 6 +- src/relax/ir/block_builder.cc | 8 +- src/relax/ir/dataflow_block_rewriter.cc | 6 +- src/relax/ir/dataflow_matcher.cc | 9 +- src/relax/op/ccl/ccl.cc | 2 +- src/relax/op/distributed/distributed.cc | 4 +- src/relax/op/distributed/linear_algebra.cc | 2 +- src/relax/op/nn/attention.cc | 2 +- src/relax/op/nn/convolution.cc | 12 +- src/relax/op/nn/nn.cc | 10 +- src/relax/op/nn/pooling.cc | 6 +- src/relax/op/op.cc | 2 +- src/relax/op/op_common.cc | 6 +- src/relax/op/op_common.h | 2 +- src/relax/op/tensor/create.cc | 6 +- src/relax/op/tensor/index.cc | 2 +- src/relax/op/tensor/linear_algebra.cc | 2 +- src/relax/op/tensor/manipulate.cc | 24 +-- src/relax/op/tensor/sampling.cc | 2 +- src/relax/op/tensor/ternary.cc | 2 +- src/relax/op/vision/nms.cc | 2 +- src/relax/transform/adjust_matmul_order.cc | 27 +-- src/relax/transform/alter_op_impl.cc | 8 +- src/relax/transform/bind_params.cc | 5 +- .../transform/combine_parallel_matmul.cc | 2 +- src/relax/transform/fuse_tir.cc | 6 +- .../transform/remove_unused_parameters.cc | 2 +- .../transform/rewrite_dataflow_reshape.cc | 2 +- .../transform/split_call_tir_by_pattern.cc | 2 +- .../transform/static_plan_block_memory.cc | 25 +-- src/relax/utils.cc | 6 +- src/s_tir/analysis/estimate_flops.cc | 4 +- src/s_tir/analysis/identify_memcpy.cc | 13 +- src/s_tir/analysis/oob_checker.cc | 8 +- .../analysis/sblock_access_region_detector.cc | 2 +- .../backend/adreno/inject_texture_alloc.cc | 4 +- src/s_tir/data_layout.cc | 6 +- .../feature_extractor/per_store_feature.cc | 22 +-- .../disallow_async_strided_mem_copy.cc | 4 +- .../meta_schedule/postproc/rewrite_layout.cc | 2 +- .../rewrite_parallel_vectorize_unroll.cc | 2 +- .../multi_level_tiling_wide_vector.cc | 2 +- src/s_tir/schedule/analysis.h | 14 +- src/s_tir/schedule/analysis/analysis.cc | 32 ++-- src/s_tir/schedule/analysis/layout.cc | 9 +- src/s_tir/schedule/analysis/reducer.cc | 6 +- src/s_tir/schedule/concrete_schedule.cc | 4 +- src/s_tir/schedule/concrete_schedule.h | 2 +- src/s_tir/schedule/ir_comparator.cc | 24 +-- .../primitive/annotate_buffer_access.cc | 4 +- .../schedule/primitive/blockize_tensorize.cc | 29 +-- src/s_tir/schedule/primitive/cache_index.cc | 8 +- .../schedule/primitive/cache_index_helpers.cc | 2 +- .../schedule/primitive/cache_read_write.cc | 20 +- src/s_tir/schedule/primitive/compute_at.cc | 39 ++-- .../schedule/primitive/compute_inline.cc | 34 ++-- .../schedule/primitive/decompose_padding.cc | 21 +- .../primitive/layout_transformation.cc | 54 +++--- .../schedule/primitive/loop_transformation.cc | 21 +- src/s_tir/schedule/primitive/pad_einsum.cc | 10 +- src/s_tir/schedule/primitive/read_write_at.cc | 4 +- .../schedule/primitive/rolling_buffer.cc | 4 +- src/s_tir/schedule/state.cc | 22 ++- src/s_tir/schedule/traced_schedule.cc | 4 +- src/s_tir/schedule/transform.cc | 4 +- src/s_tir/schedule/transform.h | 4 +- src/s_tir/transform/bound_checker.cc | 4 +- src/s_tir/transform/canonicalize_loop.cc | 4 +- src/s_tir/transform/compact_buffer_region.cc | 29 +-- src/s_tir/transform/hoist_expression.cc | 4 +- src/s_tir/transform/inject_permuted_layout.cc | 4 +- .../transform/inject_software_pipeline.cc | 36 ++-- src/s_tir/transform/inject_virtual_thread.cc | 4 +- src/s_tir/transform/loop_partition.cc | 50 ++--- src/s_tir/transform/lower_async_dma.cc | 6 +- .../transform/lower_cross_thread_reduction.cc | 4 +- src/s_tir/transform/lower_match_buffer.cc | 11 +- src/s_tir/transform/lower_thread_allreduce.cc | 2 +- src/s_tir/transform/memhammer_coalesce.cc | 8 +- .../transform/memhammer_intermediate_stage.cc | 2 +- .../transform/memhammer_lower_auto_copy.cc | 8 +- .../transform/memhammer_tensorcore_rewrite.cc | 8 +- .../transform/renormalize_split_pattern.cc | 4 +- .../transform/transform_mma_buffer_layout.cc | 4 +- src/s_tir/transform/unify_thread_binding.cc | 4 +- .../using_assume_to_reduce_branches.cc | 4 +- src/target/cuda/codegen_cuda.cc | 10 +- src/target/hexagon/llvm/codegen_hexagon.cc | 2 +- src/target/llvm/codegen_cpu.cc | 6 +- src/target/llvm/codegen_llvm.cc | 2 +- src/target/llvm/codegen_llvm.h | 2 +- src/target/opencl/intrin_rule_opencl.cc | 2 +- src/target/source/codegen_c.cc | 4 +- src/target/vulkan/codegen_spirv.cc | 2 +- src/target/vulkan/codegen_spirv.h | 2 +- src/target/webgpu/codegen_webgpu.cc | 2 +- src/te/operation/create_primfunc.cc | 12 +- src/te/operation/scan_op.cc | 2 +- src/tirx/analysis/exec_context.cc | 2 +- src/tirx/ir/buffer.cc | 16 +- src/tirx/ir/exec_scope.cc | 18 +- src/tirx/ir/index_map.cc | 107 +++++++---- src/tirx/ir/layout/axis_registry.cc | 14 +- src/tirx/ir/layout/swizzle_layout.cc | 2 +- src/tirx/ir/layout/tile_canonicalize.cc | 2 +- src/tirx/ir/layout/tile_core.cc | 12 +- src/tirx/ir/layout/tile_direct_sum_ops.cc | 12 +- src/tirx/ir/layout/tile_slice.cc | 36 ++-- src/tirx/ir/layout/tile_tile_ops.cc | 42 ++-- src/tirx/ir/stmt.cc | 4 +- src/tirx/script/builder/ir.cc | 6 +- src/tirx/transform/flatten_buffer.cc | 4 +- src/tirx/transform/ir_utils.cc | 12 +- src/tirx/transform/lower_intrin.cc | 7 +- src/tirx/transform/lower_tirx_cleanup.cc | 6 +- src/tirx/transform/lower_warp_memory.cc | 16 +- src/tirx/transform/narrow_datatype.cc | 10 +- src/tirx/transform/remove_no_op.cc | 12 +- src/tirx/transform/remove_no_op.h | 2 +- src/tirx/transform/stmt_simplify.cc | 8 +- src/tirx/transform/stmt_simplify.h | 2 +- src/tirx/transform/storage_rewrite.cc | 12 +- src/tirx/transform/tile_primitive_dispatch.cc | 12 +- src/tirx/transform/tvm_ffi_binder.cc | 14 +- src/tirx/transform/unroll_loop.cc | 2 +- src/tirx/transform/vectorize_loop.cc | 14 +- tests/cpp/arith_simplify_test.cc | 30 ++- tests/cpp/threading_backend_test.cc | 1 + .../arith/test_arith_analyzer_object.py | 59 ++++++ tests/python/tirx-base/test_tir_index_map.py | 64 ++++++- 178 files changed, 1524 insertions(+), 1172 deletions(-) create mode 100644 tests/python/arith/test_arith_analyzer_object.py diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 3d9e8ebbf93f..924cc299270a 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -25,6 +25,7 @@ #define TVM_ARITH_ANALYZER_H_ #include +#include #include #include #include @@ -32,6 +33,7 @@ #include #include #include +#include #include namespace tvm { @@ -48,8 +50,10 @@ namespace arith { // another analyzer. //------------------------------------------------------- -// Forward declare Analyzer +// Forward declare the analyzer object and its reference handle. +class AnalyzerObj; class Analyzer; +class ConstraintContext; using tirx::Var; @@ -172,9 +176,9 @@ class ConstIntBoundAnalyzer { TVM_DLL bool IsBound(const Var& var) const; private: - friend class Analyzer; + friend class AnalyzerObj; friend class ConstraintContext; - explicit ConstIntBoundAnalyzer(Analyzer* parent); + explicit ConstIntBoundAnalyzer(AnalyzerObj* parent); TVM_DLL ~ConstIntBoundAnalyzer(); /*! * \brief Update the internal state to enter constraint. @@ -251,9 +255,9 @@ class ModularSetAnalyzer { TVM_DLL void Update(const Var& var, const ModularSet& info, bool allow_override = false); private: - friend class Analyzer; + friend class AnalyzerObj; friend class ConstraintContext; - explicit ModularSetAnalyzer(Analyzer* parent); + explicit ModularSetAnalyzer(AnalyzerObj* parent); TVM_DLL ~ModularSetAnalyzer(); /*! * \brief Update the internal state to enter constraint. @@ -397,16 +401,17 @@ class RewriteSimplifier { * Note: To maintain accurate usage counters, `Analyzer` instances * should be re-used wherever possible. For example, TIR * transformations should declare a single `Analyzer` that is used - * throughout the pass, and utility functions should receive an - * `Analyzer*` from their calling scope. + * throughout the pass. Internal helper functions that only borrow + * the analyzer temporarily may receive the underlying `AnalyzerObj*` + * from their calling scope. */ TVM_DLL void SetMaximumRewriteSteps(int64_t maximum); private: - friend class Analyzer; + friend class AnalyzerObj; friend class ConstraintContext; friend class CanonicalSimplifier; - explicit RewriteSimplifier(Analyzer* parent); + explicit RewriteSimplifier(AnalyzerObj* parent); TVM_DLL ~RewriteSimplifier(); class Impl; /*! \brief Internal impl */ @@ -435,9 +440,9 @@ class CanonicalSimplifier { TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false); private: - friend class Analyzer; + friend class AnalyzerObj; friend class ConstraintContext; - explicit CanonicalSimplifier(Analyzer* parent); + explicit CanonicalSimplifier(AnalyzerObj* parent); TVM_DLL ~CanonicalSimplifier(); class Impl; /*! \brief Internal impl */ @@ -520,7 +525,7 @@ class TransitiveComparisonAnalyzer { TVM_DLL std::function EnterConstraint(const PrimExpr& constraint); private: - friend class Analyzer; + friend class AnalyzerObj; friend class ConstraintContext; TransitiveComparisonAnalyzer(); TVM_DLL ~TransitiveComparisonAnalyzer(); @@ -529,45 +534,6 @@ class TransitiveComparisonAnalyzer { std::unique_ptr impl_; }; -/*! - * \brief Constraint context. - * - * \code - * - * Var("x"); - * arith::Analyzer analyzer; - * { - * With scope(&analyzer, x % 3 == 0); - * TVM_FFI_ICHECK_EQ(analyzer.modular_set(x)->coeff, 3); - * } - * // constraint no longer in effect. - * TVM_FFI_ICHECK_NE(analyzer.modular_set(x)->coeff, 3); - * - * \endcode - */ -class ConstraintContext { - private: - // declare friend to enable with. - friend class With; - /*! - * \brief Construct a constraint context. - * \param analyzer The analyzer. - * \param constraint The constraint to be applied. - */ - ConstraintContext(Analyzer* analyzer, PrimExpr constraint) - : analyzer_(analyzer), constraint_(constraint) {} - // enter the scope. - void EnterWithScope(); - // exit the scope. - void ExitWithScope(); - /*! \brief The analyzer */ - Analyzer* analyzer_; - /*! \brief The constraint */ - PrimExpr constraint_; - /*! \brief functions to be called in recovery */ - std::vector> recovery_functions_; -}; - /*! * \brief Integer set analyzer. */ @@ -614,8 +580,8 @@ class IntSetAnalyzer { std::function EnterConstraint(const PrimExpr& constraint); private: - friend class Analyzer; - explicit IntSetAnalyzer(Analyzer* parent); + friend class AnalyzerObj; + explicit IntSetAnalyzer(AnalyzerObj* parent); TVM_DLL ~IntSetAnalyzer(); class Impl; /*! \brief Internal impl */ @@ -632,13 +598,8 @@ class IntSetAnalyzer { * If the analyzer uses memoization, we need to clear the internal * cache when information about a Var has been overridden. */ -class TVM_DLL Analyzer { +class TVM_DLL AnalyzerObj : public ffi::Object { public: - /* - * Disable copy constructor. - */ - Analyzer(const Analyzer&) = delete; - Analyzer& operator=(const Analyzer&) = delete; /*! \brief sub-analyzer: const integer bound */ ConstIntBoundAnalyzer const_int_bound; /*! \brief sub-analyzer: modular set */ @@ -652,7 +613,7 @@ class TVM_DLL Analyzer { /*! \brief sub-analyzer transitive comparisons */ TransitiveComparisonAnalyzer transitive_comparisons; /*! \brief constructor */ - Analyzer(); + AnalyzerObj(); /*! * \brief Mark the value as non-negative value globally in analyzer. * @@ -785,6 +746,90 @@ class TVM_DLL Analyzer { * \note Analyzer will call into sub-analyzers to get the result. */ PrimExpr Simplify(const PrimExpr& expr, int steps = 2); + + /*! + * \brief Analyzer methods update facts, constraints, caches, and stats. + * + * Marking the object mutable makes the `Analyzer` ObjectRef expose a + * non-const `operator->`, so APIs can take `const Analyzer&` while still + * allowing calls such as `analyzer->Bind(...)`. + * `const Analyzer&` keeps the handle itself from being rebound; it does + * not make the underlying AnalyzerObj immutable. + */ + static constexpr bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.Analyzer", AnalyzerObj, ffi::Object); +}; + +/*! + * \brief Managed reference to AnalyzerObj. + * + * Analyzer is a lightweight, reference-counted handle around a heap-allocated + * AnalyzerObj. Because it is now a first-class FFI object, an Analyzer can be + * passed across the tvm-ffi boundary (e.g. handed from Python into a C++ pass) + * and shared, so that accumulated bindings/constraints persist across calls. + * Copying an Analyzer copies the handle, and both handles share the same + * mutable AnalyzerObj state. + * This is not a deep copy of analyzer facts or caches. + * + * \sa AnalyzerObj + */ +class Analyzer : public ffi::ObjectRef { + public: + /*! \brief Default-construct a fresh analyzer (allocates an AnalyzerObj). */ + Analyzer() : Analyzer(ffi::make_object()) {} + explicit Analyzer(ffi::ObjectPtr n) : ffi::ObjectRef(std::move(n)) { + TVM_FFI_ICHECK(this->get() != nullptr); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Analyzer, ffi::ObjectRef, AnalyzerObj); +}; + +/*! + * \brief Constraint context. + * + * \code + * + * Var x("x"); + * arith::Analyzer analyzer; + * { + * With scope(analyzer, tvm::floormod(x, 3) == 0); + * TVM_FFI_ICHECK_EQ(analyzer->modular_set(x)->coeff, 3); + * } + * // constraint no longer in effect. + * TVM_FFI_ICHECK_NE(analyzer->modular_set(x)->coeff, 3); + * + * \endcode + */ +class ConstraintContext { + private: + // declare friend to enable with. + friend class With; + /*! + * \brief Construct a constraint context. + * \param analyzer The analyzer whose context is updated. The context + * keeps a reference to the analyzer while the scope is active. + * \param constraint The constraint to be applied. + */ + ConstraintContext(const Analyzer& analyzer, PrimExpr constraint) + : analyzer_(analyzer), constraint_(constraint) {} + /*! + * \brief Construct a constraint context from a borrowed analyzer object. + * \param analyzer The borrowed analyzer object. + * \param constraint The constraint to be applied. + * + * This overload is for internal callers that already operate on AnalyzerObj*. + */ + ConstraintContext(AnalyzerObj* analyzer, PrimExpr constraint) + : ConstraintContext(ffi::GetRef(analyzer), std::move(constraint)) {} + // enter the scope. + void EnterWithScope(); + // exit the scope. + void ExitWithScope(); + /*! \brief Analyzer kept alive while the context is active. */ + Analyzer analyzer_; + /*! \brief The constraint */ + PrimExpr constraint_; + /*! \brief functions to be called in recovery */ + std::vector> recovery_functions_; }; } // namespace arith diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index 89f4b9f78979..662a94eceeae 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -36,6 +36,7 @@ using tirx::IterVar; using tirx::Var; using tirx::VarNode; +class AnalyzerObj; class Analyzer; //----------------------------------------------- @@ -96,7 +97,7 @@ class IntSet : public ffi::ObjectRef { * \param ana Analyzer used in the proof. * \return Whether we can prove it is a single point */ - bool CanProveSinglePoint(Analyzer* ana) const; + bool CanProveSinglePoint(const Analyzer& ana) const; // TODO(tvm-team): update all CanProve to explicitly take // analyzer to encourage more analyzer reuse /*! \return Whether the set is proved to be bigger than 0 */ @@ -302,7 +303,7 @@ ffi::Map AsIntSet(const ffi::Map& var_dom); */ TVM_DLL ffi::Optional> EstimateRegionStrictBound( const ffi::Array& region, const ffi::Map& var_dom, const PrimExpr& predicate, - arith::Analyzer* analyzer); + const arith::Analyzer& analyzer); /*! * \brief Analyze the region with affine map, given the domain of variables and their predicate. @@ -316,7 +317,7 @@ TVM_DLL ffi::Optional> EstimateRegionStrictBound( */ TVM_DLL ffi::Optional> EstimateRegionLowerBound( const ffi::Array& region, const ffi::Map& var_dom, const PrimExpr& predicate, - arith::Analyzer* analyzer); + const arith::Analyzer& analyzer); /*! * \brief Analyze the region with affine map, given the domain of variables and their predicate @@ -331,7 +332,7 @@ TVM_DLL ffi::Optional> EstimateRegionLowerBound( TVM_DLL ffi::Array EstimateRegionUpperBound(const ffi::Array& region, const ffi::Map& var_dom, const PrimExpr& predicate, - arith::Analyzer* analyzer); + const arith::Analyzer& analyzer); } // namespace arith } // namespace tvm diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index ede0e04d59d0..4e9ac512aac9 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -306,7 +306,7 @@ class IterMapResult : public ffi::ObjectRef { */ IterMapResult DetectIterMap(const ffi::Array& indices, const ffi::Map& input_iters, const PrimExpr& predicate, - IterMapLevel check_level, arith::Analyzer* analyzer, + IterMapLevel check_level, const arith::Analyzer& analyzer, bool simplify_trivial_iterators = true); /*! @@ -323,7 +323,7 @@ IterMapResult DetectIterMap(const ffi::Array& indices, ffi::Array IterMapSimplify(const ffi::Array& indices, const ffi::Map& input_iters, const PrimExpr& input_pred, IterMapLevel check_level, - arith::Analyzer* analyzer, + const arith::Analyzer& analyzer, bool simplify_trivial_iterators = true); /*! @@ -380,7 +380,7 @@ ffi::Array> SubspaceDivide(const ffi::Array& bind const ffi::Map& input_iters, const ffi::Array& sub_iters, const PrimExpr& predicate, IterMapLevel check_level, - arith::Analyzer* analyzer, + const arith::Analyzer& analyzer, bool simplify_trivial_iterators = true); /*! @@ -407,7 +407,7 @@ PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr); * \note This function is useful to detect iterator stride patterns. */ IterSumExpr NormalizeToIterSum(PrimExpr index, const ffi::Map& input_iters, - arith::Analyzer* analyzer); + const arith::Analyzer& analyzer); } // namespace arith } // namespace tvm diff --git a/include/tvm/ir/scope_stack.h b/include/tvm/ir/scope_stack.h index 694d35e19ec1..b5ea10656f2f 100644 --- a/include/tvm/ir/scope_stack.h +++ b/include/tvm/ir/scope_stack.h @@ -44,7 +44,7 @@ namespace tvm { * * // In VisitStmt_(ForNode): * return constraints.WithNewScope([&]() -> Stmt { - * constraints.Current().Emplace(&analyzer, condition); + * constraints.Current().Emplace(analyzer, condition); * return StmtExprMutator::VisitStmt_(op); * }); * \endcode diff --git a/include/tvm/ir/with_context.h b/include/tvm/ir/with_context.h index 1b7502f33b2c..5c7fe6d0f26b 100644 --- a/include/tvm/ir/with_context.h +++ b/include/tvm/ir/with_context.h @@ -103,8 +103,8 @@ class With { * * \code * WithGroup group; - * group.Emplace(&analyzer, cond1); // constructs and enters - * group.Emplace(&analyzer, cond2); // constructs and enters + * group.Emplace(analyzer, cond1); // constructs and enters + * group.Emplace(analyzer, cond2); // constructs and enters * // destructor: exits cond2, then cond1 * \endcode * diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 3283f3627a3c..c71677b9341d 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -55,7 +55,7 @@ namespace relax { * two shapes equals to each other during runtime. */ TVM_DLL bool CanProveShapeEqual(const ffi::Array& lhs, const ffi::Array& rhs, - arith::Analyzer* ana); + const arith::Analyzer& ana); /*! * \brief Can prove the two symbolic shape expressions equals to each other. @@ -68,7 +68,7 @@ TVM_DLL bool CanProveShapeEqual(const ffi::Array& lhs, const ffi::Arra * if result is false, there is still possibility that * two shapes equals to each other during runtime. */ -TVM_DLL bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs, arith::Analyzer* ana); +TVM_DLL bool CanProveShapeEqual(const Expr& lhs, const Expr& rhs, const arith::Analyzer& ana); //----------------------------------- // Foundational StructInfo analysis @@ -92,13 +92,22 @@ TVM_DLL StructInfo StructInfoFromType(const Type& type); * \param finfo The function struct info. * \param call The call expression to be derived. * \param ctx The builder context. - * \param ana Optional context analyzer to prove symbolic expression equality. * \return The derived struct info of the call. * \note call->op field is ignored during derivation and we only rely on information * presented by func_sinfo. */ TVM_DLL StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call, - const BlockBuilder& ctx, arith::Analyzer* ana = nullptr); + const BlockBuilder& ctx); +/*! + * \brief Derive the call's ret value struct info using a caller-provided analyzer. + * \param finfo The function struct info. + * \param call The call expression to be derived. + * \param ctx The builder context. + * \param ana Context analyzer to prove symbolic expression equality. + * \return The derived struct info of the call. + */ +TVM_DLL StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call, + const BlockBuilder& ctx, const arith::Analyzer& ana); /*! * \brief Erase the info to a corresponding more coarse grained @@ -152,15 +161,29 @@ TVM_DLL StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Ca * \param f_var_map callback function to specify * whether a var is defined in the target scope and the value it maps to, * return nullopt if var is undefined. - * \param ana Optional context analyzer to prove symbolic expression equality. * * \return the corresponding erased struct info. */ TVM_DLL StructInfo EraseToWellDefined( const StructInfo& info, std::function(const tirx::Var& var)> f_shape_var_map = nullptr, - std::function(const Var& var)> f_var_map = nullptr, - arith::Analyzer* ana = nullptr); + std::function(const Var& var)> f_var_map = nullptr); +/*! + * \brief EraseToWellDefined overload using a caller-provided analyzer. + * \param info The struct info. + * \param f_shape_var_map callback function to specify + * whether a symbolic shape var is defined and the value it maps to, + * return nullopt if var is undefined. + * \param f_var_map callback function to specify + * whether a var is defined in the target scope and the value it maps to, + * return nullopt if var is undefined. + * \param ana Context analyzer to prove symbolic expression equality. + * \return the corresponding erased struct info. + */ +TVM_DLL StructInfo EraseToWellDefined( + const StructInfo& info, + std::function(const tirx::Var& var)> f_shape_var_map, + std::function(const Var& var)> f_var_map, const arith::Analyzer& ana); /*! * \brief EraseToWellDefined variant with map. @@ -171,13 +194,27 @@ TVM_DLL StructInfo EraseToWellDefined( * \param var_map map to specify * whether a var is defined in the target scope and the value it maps to, * return nullopt if var is undefined. - * \param ana Optional context analyzer to prove symbolic expression equality. * * \return the corresponding erased struct info. */ TVM_DLL StructInfo EraseToWellDefined(const StructInfo& info, ffi::Map shape_var_map, - ffi::Map var_map, arith::Analyzer* ana = nullptr); + ffi::Map var_map); +/*! + * \brief EraseToWellDefined map overload using a caller-provided analyzer. + * \param info The struct info. + * \param shape_var_map map to specify + * whether a symbolic shape var is defined and the value it maps to, + * return nullopt if var is undefined. + * \param var_map map to specify + * whether a var is defined in the target scope and the value it maps to, + * return nullopt if var is undefined. + * \param ana Context analyzer to prove symbolic expression equality. + * \return the corresponding erased struct info. + */ +TVM_DLL StructInfo EraseToWellDefined(const StructInfo& info, + ffi::Map shape_var_map, + ffi::Map var_map, const arith::Analyzer& ana); /*! * \brief Fine grained result of base check. @@ -233,24 +270,40 @@ enum class BaseCheckResult { * * \param base The base struct info. * \param derived The derived struct info. - * \param ana Optional context analyzer to prove symbolic expression equality. + * \return Whether the relation holds. + * + * \sa BaseCheckResult + */ +TVM_DLL BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived); +/*! + * \brief Run a base check using a caller-provided analyzer. + * \param base The base struct info. + * \param derived The derived struct info. + * \param ana Context analyzer to prove symbolic expression equality. * \return Whether the relation holds. * * \sa BaseCheckResult */ TVM_DLL BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived, - arith::Analyzer* ana = nullptr); + const arith::Analyzer& ana); /*! * \brief Check the relation of two struct info to see if one subsumes another one. * * \param base The base struct info. * \param derived The derived struct info. - * \param ana Optional context analyzer to prove symbolic expression equality. + * \return Whether the relation holds. + */ +TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived); +/*! + * \brief Check whether one struct info subsumes another using a caller-provided analyzer. + * \param base The base struct info. + * \param derived The derived struct info. + * \param ana Context analyzer to prove symbolic expression equality. * \return Whether the relation holds. */ TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived, - arith::Analyzer* ana = nullptr); + const arith::Analyzer& ana); /*! * \brief Return the condition for which base is a superset of derived @@ -279,11 +332,18 @@ TVM_DLL PrimExpr StructInfoBaseCheckPrecondition(const StructInfo& base, const S * * \param lhs The left operand. * \param rhs The right operand. - * \param ana Optional context analyzer to prove symbolic expression equality. + * \return The unified information. + */ +TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs); +/*! + * \brief Unify two struct infos using a caller-provided analyzer. + * \param lhs The left operand. + * \param rhs The right operand. + * \param ana Context analyzer to prove symbolic expression equality. * \return The unified information. */ TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, - arith::Analyzer* ana = nullptr); + const arith::Analyzer& ana); /*! * \brief Get the TIR variables that appear in the input struct info. diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index d3853bb9179d..750f181114a4 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -255,7 +255,7 @@ class BlockBuilderNode : public ffi::Object { * \brief Get the analyzer of the BlockBuilder. * \return The BlockBuilder's arithmetic analyzer. */ - virtual arith::Analyzer* GetAnalyzer() = 0; + virtual arith::Analyzer GetAnalyzer() = 0; static constexpr const bool _type_mutable = true; TVM_FFI_DECLARE_OBJECT_INFO("relax.BlockBuilder", BlockBuilderNode, ffi::Object); diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h index 3ec0b555b5ef..58d46f04380b 100644 --- a/include/tvm/relax/dataflow_pattern.h +++ b/include/tvm/relax/dataflow_pattern.h @@ -44,8 +44,9 @@ namespace tvm { namespace arith { +class AnalyzerObj; class Analyzer; -} +} // namespace arith namespace relax { diff --git a/include/tvm/relax/distributed/axis_group_graph.h b/include/tvm/relax/distributed/axis_group_graph.h index 2ce162d37062..86b34b71352a 100644 --- a/include/tvm/relax/distributed/axis_group_graph.h +++ b/include/tvm/relax/distributed/axis_group_graph.h @@ -59,7 +59,7 @@ class BufferAxisHash { * \return The iter var whose extent to be changed */ Var GetShardingVarFromIndex(PrimExpr index, ffi::Map var_range, - arith::Analyzer* analyzer); + const arith::Analyzer& analyzer); /*! * \brief Construct an axis group graph from a PrimFunc. Two buffer axis are connected if they @@ -125,7 +125,7 @@ class BufferAxisGraphExtractor : public StmtExprVisitor { } bool Match(PrimExpr a, PrimExpr buffer_shape_a, PrimExpr b, PrimExpr buffer_shape_b, - arith::Analyzer* analyzer) { + const arith::Analyzer& analyzer) { if (b.as()) { std::swap(a, b); std::swap(buffer_shape_a, buffer_shape_b); @@ -173,7 +173,7 @@ class BufferAxisGraphExtractor : public StmtExprVisitor { ffi::Array another_indices = another_access_pr.second; for (int j = 0; j < static_cast(another_indices.size()); j++) { if (Match(indices[i], buffer->shape[i], another_indices[j], another_buffer->shape[j], - &analyzer)) { + analyzer)) { JoinBufferAxis({buffer, i}, {another_buffer, j}); } } diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h index bfbcaa069818..77f8bab5553f 100644 --- a/include/tvm/relax/utils.h +++ b/include/tvm/relax/utils.h @@ -75,7 +75,7 @@ TVM_DLL StructInfo Bind(const StructInfo& sinfo, * \return A map of TIR variables to TIR expressions */ TVM_DLL tvm::ffi::Map InferSymbolicVarMap( - const tvm::ffi::Map& binds, arith::Analyzer* analyzer); + const tvm::ffi::Map& binds, const arith::Analyzer& analyzer); /*! * \brief Check if the given StructInfo is for a boolean scalar (tensor of rank 0 with a boolean diff --git a/include/tvm/s_tir/analysis.h b/include/tvm/s_tir/analysis.h index e90fe15ac3bf..b0cf7b38b9d5 100644 --- a/include/tvm/s_tir/analysis.h +++ b/include/tvm/s_tir/analysis.h @@ -90,8 +90,9 @@ const tirx::SBlockNode* FindAnchorBlock(const IRModule& mod); } // namespace tirx namespace arith { +class AnalyzerObj; class Analyzer; -} +} // namespace arith namespace s_tir { @@ -138,7 +139,8 @@ struct MemCpyDetails { * \param analyzer The analyzer with which to check any algebraic expressions * \returns The source and destination regions being copied, if the loop is equivalent to memcpy. */ -TVM_DLL std::optional IdentifyMemCpy(const For& loop, arith::Analyzer* analyzer); +TVM_DLL std::optional IdentifyMemCpy(const For& loop, + const arith::Analyzer& analyzer); /*! * \brief Calculate the allocated memory per scope in bytes needed inside the TIR PrimFunc diff --git a/include/tvm/tirx/analysis.h b/include/tvm/tirx/analysis.h index 1279455c8e2b..a453a3ae5bea 100644 --- a/include/tvm/tirx/analysis.h +++ b/include/tvm/tirx/analysis.h @@ -37,10 +37,6 @@ namespace tvm { -namespace arith { -class Analyzer; -} - namespace tirx { /*! diff --git a/include/tvm/tirx/index_map.h b/include/tvm/tirx/index_map.h index 7d4c6684b118..2191b51b4e62 100644 --- a/include/tvm/tirx/index_map.h +++ b/include/tvm/tirx/index_map.h @@ -36,7 +36,7 @@ namespace tvm { namespace arith { class Analyzer; -} +} // namespace arith } // namespace tvm namespace tvm { @@ -91,21 +91,27 @@ class IndexMapNode : public ffi::Object { IndexMapNode() {} /*! - * \brief Map indices to the output space + * \brief Map indices to the output space using a fresh analyzer. * * \param indices The indices in the input space. Should contain * one value for each variable in `initial_indices`. + * \returns The indices in the output space. Contains one value for + * each expression in `final_indices`. + */ + ffi::Array MapIndices(const ffi::Array& indices) const; + /*! + * \brief Map indices to the output space using an existing analyzer. * - * \param analyzer An optional analyzer to be used to simplify the - * resulting expressions. If null, will use a fresh analyzer. - * + * \param indices The indices in the input space. Should contain + * one value for each variable in `initial_indices`. + * \param analyzer An analyzer to be used to simplify the resulting expressions. * \returns The indices in the output space. Contains one value for * each expression in `final_indices`. */ ffi::Array MapIndices(const ffi::Array& indices, - arith::Analyzer* analyzer) const; + const arith::Analyzer& analyzer) const; - /*! \brief Map a memory range to the output space + /*! \brief Map a memory range to the output space using a fresh analyzer. * * If contiguous memory locations in the input space are not * necessarily contiguous in the output space (e.g. `lambda i: @@ -114,27 +120,44 @@ class IndexMapNode : public ffi::Object { * * \param ranges The ranges in the input space. Should contain one * value for each variable in `initial_indices`. + * \returns The ranges in the output space. Contains one value for + * each expression in `final_indices`. + */ + ffi::Array MapRanges(const ffi::Array& ranges) const; + /*! \brief Map a memory range to the output space using an existing analyzer. * - * \param analyzer An optional analyzer to be used to simplify the - * resulting expressions. If null, will use a fresh analyzer. + * If contiguous memory locations in the input space are not + * necessarily contiguous in the output space (e.g. `lambda i: + * [8*(i%8) + (i//8)]`), then this will return the smallest range + * such that all valid indices are contained within the given range. * + * \param ranges The ranges in the input space. Should contain one + * value for each variable in `initial_indices`. + * \param analyzer An analyzer to be used to simplify the resulting expressions. * \returns The ranges in the output space. Contains one value for * each expression in `final_indices`. */ - ffi::Array MapRanges(const ffi::Array& ranges, arith::Analyzer* analyzer) const; + ffi::Array MapRanges(const ffi::Array& ranges, + const arith::Analyzer& analyzer) const; - /*! \brief Map a buffer shape to the output space + /*! \brief Map a buffer shape to the output space using a fresh analyzer. * * \param shape The buffer shape in the input space. Should contain * one value for each variable in `initial_indices`. + * \returns The buffer shape in the output space. Contains one + * value for each expression in `final_indices`. + */ + ffi::Array MapShape(const ffi::Array& shape) const; + /*! \brief Map a buffer shape to the output space using an existing analyzer. * - * \param analyzer An optional analyzer to be used to simplify the - * resulting expressions. If null, will use a fresh analyzer. - * + * \param shape The buffer shape in the input space. Should contain + * one value for each variable in `initial_indices`. + * \param analyzer An analyzer to be used to simplify the resulting expressions. * \returns The buffer shape in the output space. Contains one * value for each expression in `final_indices`. */ - ffi::Array MapShape(const ffi::Array& shape, arith::Analyzer* analyzer) const; + ffi::Array MapShape(const ffi::Array& shape, + const arith::Analyzer& analyzer) const; /* \brief Map an Tensor according to this index map * @@ -187,15 +210,28 @@ class IndexMap : public ffi::ObjectRef { static IndexMap FromFunc(int ndim, ffi::TypedFunction(ffi::Array)> func, ffi::Optional inverse_index_map = std::nullopt); - /*! \brief Generate the inverse mapping. + /*! \brief Generate the inverse mapping using a fresh analyzer. + * + * The range of the input indices is required in order to ensure + * that the transformation is bijective over the input domain. + * + * If the user has supplied an `inverse_index_map`, that map is + * assumed to be correct and bijective, and is returned. + * \param initial_ranges The ranges of the input indices. + */ + IndexMap Inverse(ffi::Array initial_ranges) const; + /*! \brief Generate the inverse mapping using an existing analyzer. * * The range of the input indices is required in order to ensure * that the transformation is bijective over the input domain. * * If the user has supplied an `inverse_index_map`, that map is * assumed to be correct and bijective, and is returned. + * \param initial_ranges The ranges of the input indices. + * \param analyzer An analyzer to be used while deriving and validating + * the inverse. */ - IndexMap Inverse(ffi::Array initial_ranges, arith::Analyzer* analyzer) const; + IndexMap Inverse(ffi::Array initial_ranges, const arith::Analyzer& analyzer) const; /*! \brief Rename the variables in the index map and ensure the names are unique. * @@ -208,17 +244,31 @@ class IndexMap : public ffi::ObjectRef { IndexMap RenameVariables( const std::function(const Var& var)>& f_name_map = nullptr) const; - /*! \brief Generate the inverse mapping. + /*! \brief Generate the inverse mapping using a fresh analyzer. + * + * Determine the inverse, where the output range may contain + * addresses that do not correspond to an address in the input + * range. + * + * \param initial_ranges The ranges of the input indices. + * \return The inverted index map, along with the predicate for + * which the inverse maps to a valid range. + */ + std::pair NonSurjectiveInverse(ffi::Array initial_ranges) const; + /*! \brief Generate the inverse mapping using an existing analyzer. * * Determine the inverse, where the output range may contain * addresses that do not correspond to an address in the input * range. * + * \param initial_ranges The ranges of the input indices. + * \param analyzer An analyzer to be used while deriving the inverse and + * padding predicate. * \return The inverted index map, along with the predicate for * which the inverse maps to a valid range. */ std::pair NonSurjectiveInverse(ffi::Array initial_ranges, - arith::Analyzer* analyzer) const; + const arith::Analyzer& analyzer) const; TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IndexMap, ffi::ObjectRef, IndexMapNode); }; diff --git a/include/tvm/topi/detail/constant_utils.h b/include/tvm/topi/detail/constant_utils.h index 07df5c470bf4..bbf4f906bdb0 100644 --- a/include/tvm/topi/detail/constant_utils.h +++ b/include/tvm/topi/detail/constant_utils.h @@ -133,7 +133,7 @@ inline bool EqualCheck(PrimExpr lhs, PrimExpr rhs) { tvm::tirx::ExprDeepEqual expr_equal; bool result = expr_equal(lhs, rhs); if (!result) { - PrimExpr t = tvm::arith::Analyzer().Simplify(lhs - rhs); + PrimExpr t = tvm::arith::Analyzer()->Simplify(lhs - rhs); if (const IntImmNode* i = t.as()) { result = i->value == 0; } diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 23a22359d261..7df01fe8c1b4 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -184,7 +184,7 @@ inline tvm::te::Tensor pad( output_shape.push_back(t->shape[i]); } else { output_shape.push_back( - analyzer.Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i])); + analyzer->Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i])); } } } else { @@ -213,7 +213,7 @@ inline tvm::te::Tensor pad( indices.push_back(ovars[i]); } if (!topi::detail::EqualCheck(pad_after_int32[i], 0)) { - sel.push_back(analyzer.Simplify(ovars[i] < pad_before_int32[i] + t->shape[i])); + sel.push_back(analyzer->Simplify(ovars[i] < pad_before_int32[i] + t->shape[i])); } if (pad_mode == "edge") { pad_idx.push_back( diff --git a/include/tvm/topi/nn/bnn.h b/include/tvm/topi/nn/bnn.h index e474cff16941..5a3ba871d56b 100644 --- a/include/tvm/topi/nn/bnn.h +++ b/include/tvm/topi/nn/bnn.h @@ -59,7 +59,7 @@ inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, int axis, auto n = ishape.size(); ffi::Array oshape; for (size_t i = 0; i < n; ++i) { - oshape.push_back(i == static_cast(axis) ? analyzer.Simplify(indexdiv(ishape[i], 32)) + oshape.push_back(i == static_cast(axis) ? analyzer->Simplify(indexdiv(ishape[i], 32)) : ishape[i]); } diff --git a/include/tvm/topi/nn/dilate.h b/include/tvm/topi/nn/dilate.h index 52ef33c80249..e6f280c4bcba 100644 --- a/include/tvm/topi/nn/dilate.h +++ b/include/tvm/topi/nn/dilate.h @@ -76,7 +76,7 @@ inline Tensor dilate(const Tensor& x, ffi::Array strides, double dilat ffi::Array out_shape; arith::Analyzer analyzer; for (size_t i = 0; i < n; ++i) { - out_shape.push_back(analyzer.Simplify((x->shape[i] - 1) * (strides[i] + 1))); + out_shape.push_back(analyzer->Simplify((x->shape[i] - 1) * (strides[i] + 1))); } return tvm::te::compute( diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h index 69e9aae4840e..3cdb5b03c58a 100644 --- a/include/tvm/topi/nn/pooling.h +++ b/include/tvm/topi/nn/pooling.h @@ -87,9 +87,9 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, pad_after.Set(width_axis, pad_right); arith::Analyzer analyzer; auto out_height = - analyzer.Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1); + analyzer->Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1); auto out_width = - analyzer.Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1); + analyzer->Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1); auto dheight = tvm::te::reduce_axis(Range(0, kernel_height), "dh"); auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width), "dw"); @@ -573,10 +573,10 @@ inline Tensor pool_impl_nd(const Tensor& x, const ffi::Array& kernel_s // If not, we skip the last window as it would start in the bottom padded region, // we need to minus 1 to get the correct output shape. auto invalid_last = (raw_out - 1) * stride[i] >= data_shape[ii] + pad_head[i]; - auto out_dim = analyzer.Simplify(if_then_else(invalid_last, raw_out - 1, raw_out)); + auto out_dim = analyzer->Simplify(if_then_else(invalid_last, raw_out - 1, raw_out)); out_shape.Set(ii, out_dim); } else { - auto out_dim = analyzer.Simplify(raw_out); + auto out_dim = analyzer->Simplify(raw_out); out_shape.Set(ii, out_dim); } } diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index ed72c08e5a87..a46c1c05b344 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -497,7 +497,7 @@ inline Tensor concatenate(const ffi::Array& inputs, int axis = 0, for (size_t i = 1; i < axis_sizes.size(); ++i) { join_size += axis_sizes[i]; } - join_size = analyzer.Simplify(join_size); + join_size = analyzer->Simplify(join_size); ffi::Array out_shape; for (size_t i = 0; i < inputs[0]->shape.size(); ++i) { out_shape.push_back(i == static_cast(axis) ? join_size : inputs[0]->shape[i]); @@ -733,8 +733,8 @@ inline te::Tensor dynamic_strided_slice_with_axes( ffi::Array out_shape = x->shape; for (size_t i = 0; i < begin.size(); i++) { int axis = static_cast(axes[i]); - PrimExpr new_shape = - analyzer.Simplify(GetLength(begin[i], end[i], strides[i], out_shape[axis], assume_inbound)); + PrimExpr new_shape = analyzer->Simplify( + GetLength(begin[i], end[i], strides[i], out_shape[axis], assume_inbound)); out_shape.Set(axis, new_shape); } @@ -790,7 +790,7 @@ inline Tensor dynamic_strided_slice(const Tensor& x, const ffi::Array& if (!begin[i]->IsInstance() && !end[i]->IsInstance() && !strides[i]->IsInstance()) { out_shape.push_back( - analyzer.Simplify(GetLength(begin[i], end[i], strides[i], x->shape[i], assume_inbound))); + analyzer->Simplify(GetLength(begin[i], end[i], strides[i], x->shape[i], assume_inbound))); } else { out_shape.push_back(tvm::tirx::Var("dim")); } @@ -1744,10 +1744,10 @@ inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr arith::Analyzer analyzer; PrimExpr num_elem; bool is_all_int = start.dtype().is_int() && stop.dtype().is_int() && step.dtype().is_int(); - if (is_all_int && analyzer.CanProveGreaterEqual(step, 1)) { + if (is_all_int && analyzer->CanProveGreaterEqual(step, 1)) { // fast path for integer arange when step is positive num_elem = tvm::floordiv((stop - start + step - 1), step); - } else if (is_all_int && analyzer.CanProveLess(step, 0)) { + } else if (is_all_int && analyzer->CanProveLess(step, 0)) { // fast path for integer arange when step is negative num_elem = tvm::floordiv((start - stop - step - 1), -step); } else { @@ -1755,7 +1755,7 @@ inline Tensor arange(const PrimExpr& start, const PrimExpr& stop, const PrimExpr num_elem = tvm::cast(DefaultIndexType(), tvm::ceil(tvm::cast(tvm::DataType::Float(32), stop - start) / step)); } - num_elem = analyzer.Simplify(num_elem); + num_elem = analyzer->Simplify(num_elem); return compute( {num_elem}, @@ -1962,13 +1962,12 @@ inline Tensor meta_schedule_layout_transform( for (const PrimExpr& e : src->shape) { iter_domain.push_back(Range::FromMinExtent(make_zero(e->dtype), e)); } - ffi::Array post_transform_shape = index_map->MapShape(src->shape, &analyzer); + ffi::Array post_transform_shape = index_map->MapShape(src->shape, analyzer); return compute( post_transform_shape, - [src, inv = index_map.Inverse(iter_domain, &analyzer), + [src, inv = index_map.Inverse(iter_domain, analyzer), &analyzer](const ffi::Array& indices) -> PrimExpr { - return src( - inv->MapIndices(ffi::Array{indices.begin(), indices.end()}, &analyzer)); + return src(inv->MapIndices(ffi::Array{indices.begin(), indices.end()}, analyzer)); }, name, tag); } diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index ea70c4de3d0f..c3b77a9603ff 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -100,31 +100,17 @@ def __exit__(self, ptype, value, trace): self._fexit() -class Analyzer: +@tvm_ffi.register_object("arith.Analyzer") +class Analyzer(Object): """Integer arithmetic analyzer - This is a stateful analyzer class that can - be used to perform various symbolic integer analysis. + This is a stateful analyzer class that can be used to perform + various symbolic integer analysis. The same analyzer instance can + be passed to FFI APIs to share accumulated facts across calls. """ def __init__(self): - _mod = _ffi_api.CreateAnalyzer() - self._const_int_bound = _mod("const_int_bound") - self._const_int_bound_update = _mod("const_int_bound_update") - self._const_int_bound_is_bound = _mod("const_int_bound_is_bound") - self._bind = _mod("bind") - self._modular_set = _mod("modular_set") - self._simplify = _mod("Simplify") - self._rewrite_simplify = _mod("rewrite_simplify") - self._get_rewrite_simplify_stats = _mod("get_rewrite_simplify_stats") - self._reset_rewrite_simplify_stats = _mod("reset_rewrite_simplify_stats") - self._canonical_simplify = _mod("canonical_simplify") - self._int_set = _mod("int_set") - self._enter_constraint_context = _mod("enter_constraint_context") - self._can_prove_equal = _mod("can_prove_equal") - self._can_prove = _mod("can_prove") - self._get_enabled_extensions = _mod("get_enabled_extensions") - self._set_enabled_extensions = _mod("set_enabled_extensions") + self.__init_handle_by_constructor__(_ffi_api.Analyzer) def const_int_bound(self, expr: tirx.PrimExpr) -> ConstIntBound: """Find constant integer bound for expr. @@ -139,7 +125,7 @@ def const_int_bound(self, expr: tirx.PrimExpr) -> ConstIntBound: bound : ConstIntBound The result bound """ - return self._const_int_bound(expr) + return _ffi_api.AnalyzerConstIntBound(self, expr) def const_int_bound_is_bound(self, var: tirx.Var) -> bool: """Check if a variable is bound to a range. @@ -154,7 +140,7 @@ def const_int_bound_is_bound(self, var: tirx.Var) -> bool: result : bool Whether the variable is bound to a range. """ - return self._const_int_bound_is_bound(var) + return _ffi_api.AnalyzerConstIntBoundIsBound(self, var) def modular_set(self, expr: tirx.PrimExpr) -> ModularSet: """Find a modular set that expr belongs to. @@ -169,7 +155,7 @@ def modular_set(self, expr: tirx.PrimExpr) -> ModularSet: result : ModularSet The result. """ - return self._modular_set(expr) + return _ffi_api.AnalyzerModularSet(self, expr) def simplify(self, expr: tirx.PrimExpr, steps: int = 2) -> tirx.PrimExpr: """Simplify expression via both rewrite and canonicalization. @@ -189,7 +175,7 @@ def simplify(self, expr: tirx.PrimExpr, steps: int = 2) -> tirx.PrimExpr: result : Expr The result. """ - return self._simplify(expr, steps) + return _ffi_api.AnalyzerSimplify(self, expr, steps) def rewrite_simplify(self, expr: tirx.PrimExpr) -> tirx.PrimExpr: """Simplify expression via rewriting rules. @@ -204,14 +190,14 @@ def rewrite_simplify(self, expr: tirx.PrimExpr) -> tirx.PrimExpr: result : Expr The result. """ - return self._rewrite_simplify(expr) + return _ffi_api.AnalyzerRewriteSimplify(self, expr) @property def rewrite_simplify_stats(self): - return self._get_rewrite_simplify_stats() + return _ffi_api.AnalyzerGetRewriteSimplifyStats(self) def reset_rewrite_simplify_stats(self): - self._reset_rewrite_simplify_stats() + _ffi_api.AnalyzerResetRewriteSimplifyStats(self) def canonical_simplify(self, expr: tirx.PrimExpr) -> tirx.PrimExpr: """Simplify expression via canonicalization. @@ -226,7 +212,7 @@ def canonical_simplify(self, expr: tirx.PrimExpr) -> tirx.PrimExpr: result : Expr The result. """ - return self._canonical_simplify(expr) + return _ffi_api.AnalyzerCanonicalSimplify(self, expr) def int_set(self, expr: tirx.PrimExpr, dom_map: dict[tirx.Var, IntSet]) -> IntSet: """Compute a symbolic IntSet that covers expr for all values in dom_map. @@ -244,7 +230,7 @@ def int_set(self, expr: tirx.PrimExpr, dom_map: dict[tirx.Var, IntSet]) -> IntSe result : IntSet The result. """ - return self._int_set(expr, dom_map) + return _ffi_api.AnalyzerIntSet(self, expr, dom_map) def can_prove( self, expr: tirx.PrimExpr, strength: ProofStrength = ProofStrength.DEFAULT @@ -264,7 +250,7 @@ def can_prove( result : Expr The result. """ - return self._can_prove(expr, strength) + return _ffi_api.AnalyzerCanProve(self, expr, strength) def bind( self, @@ -285,7 +271,7 @@ def bind( allow_override : bool Whether to allow overriding an existing binding for the variable. """ - return self._bind(var, expr, allow_override) + return _ffi_api.AnalyzerBind(self, var, expr, allow_override) def constraint_scope(self, constraint: tirx.PrimExpr) -> ConstraintScope: """Create a constraint scope. @@ -306,7 +292,7 @@ def constraint_scope(self, constraint: tirx.PrimExpr) -> ConstraintScope: x = te.var("x") analyzer = tvm.arith.Analyzer() - with analzyer.constraint_scope(x % 3 == 0): + with analyzer.constraint_scope(x % 3 == 0): # constraint in effect assert analyzer.modular_set(x).coeff == 3 # constraint no longer in effect @@ -314,7 +300,7 @@ def constraint_scope(self, constraint: tirx.PrimExpr) -> ConstraintScope: """ def _fenter(): - return self._enter_constraint_context(constraint) + return _ffi_api.AnalyzerEnterConstraintContext(self, constraint) return ConstraintScope(_fenter) @@ -333,7 +319,7 @@ def update(self, var: tirx.Var, info: ConstIntBound, override: bool = False) -> Whether allow override. """ if isinstance(info, ConstIntBound): - self._const_int_bound_update(var, info, override) + _ffi_api.AnalyzerConstIntBoundUpdate(self, var, info, override) else: raise TypeError(f"Do not know how to handle type {type(info)}") @@ -353,12 +339,12 @@ def can_prove_equal(self, lhs: tirx.PrimExpr, rhs: tirx.PrimExpr) -> bool: result: bool Whether we can prove that lhs == rhs """ - return self._can_prove_equal(lhs, rhs) + return _ffi_api.AnalyzerCanProveEqual(self, lhs, rhs) @property def enabled_extensions(self) -> Extension: """Return the currently enabled extensions""" - value = self._get_enabled_extensions() + value = _ffi_api.AnalyzerGetEnabledExtensions(self) return Extension(value) @enabled_extensions.setter @@ -372,4 +358,4 @@ def enabled_extensions(self, flags: int | Extension): The extensions to enable. """ flags = Extension(flags).value - self._set_enabled_extensions(flags) + _ffi_api.AnalyzerSetEnabledExtensions(self, flags) diff --git a/python/tvm/tirx/function.py b/python/tvm/tirx/function.py index fb0e388d73b0..d0b10ca7d0f7 100644 --- a/python/tvm/tirx/function.py +++ b/python/tvm/tirx/function.py @@ -456,35 +456,39 @@ def is_equivalent_to(self, other_map: "IndexMap") -> bool: return True - def map_indices(self, indices: list[PrimExpr]) -> list[PrimExpr]: + def map_indices(self, indices: list[PrimExpr], analyzer=None) -> list[PrimExpr]: """Apply the index map to a set of indices Parameters ---------- indices : List[PrimExpr] The indices to be mapped + analyzer : Optional[tvm.arith.Analyzer] + The analyzer to use while simplifying mapped indices. Returns ------- result : List[PrimExpr] The mapped indices """ - return _ffi_api.IndexMapMapIndices(self, indices) + return _ffi_api.IndexMapMapIndices(self, indices, analyzer) - def map_shape(self, shape: list[PrimExpr]) -> list[PrimExpr]: + def map_shape(self, shape: list[PrimExpr], analyzer=None) -> list[PrimExpr]: """Apply the index map to a buffer shape Parameters ---------- shape : List[PrimExpr] The buffer shape to be mapped + analyzer : Optional[tvm.arith.Analyzer] + The analyzer to use while simplifying mapped shape expressions. Returns ------- result : List[PrimExpr] The mapped shape """ - return _ffi_api.IndexMapMapShape(self, shape) + return _ffi_api.IndexMapMapShape(self, shape, analyzer) def map_tensor(self, arr_src: Tensor) -> Tensor: """Apply thie index map to transform the layout of the input Tensor @@ -501,7 +505,7 @@ def map_tensor(self, arr_src: Tensor) -> Tensor: """ return _ffi_api.IndexMapMapTensor(self, arr_src) - def inverse(self, shape: list[Range | PrimExpr]) -> "IndexMap": + def inverse(self, shape: list[Range | PrimExpr], analyzer=None) -> "IndexMap": """Return the inverse of the map Throws an error if the function is not bijective. @@ -513,6 +517,8 @@ def inverse(self, shape: list[Range | PrimExpr]) -> "IndexMap": The region over which the inverse should be determined. Used for validating that the mapping is bijective over this range. + analyzer : Optional[tvm.arith.Analyzer] + The analyzer to use while deriving and validating the inverse. Returns ------- @@ -522,9 +528,11 @@ def inverse(self, shape: list[Range | PrimExpr]) -> "IndexMap": """ shape = [dim if isinstance(dim, Range) else Range(0, dim) for dim in shape] - return _ffi_api.IndexMapInverse(self, shape) + return _ffi_api.IndexMapInverse(self, shape, analyzer) - def non_surjective_inverse(self, shape: list[Range | PrimExpr]) -> tuple["IndexMap", PrimExpr]: + def non_surjective_inverse( + self, shape: list[Range | PrimExpr], analyzer=None + ) -> tuple["IndexMap", PrimExpr]: """Return the inverse of the map Can be applied to transformations that introduce padding. @@ -535,6 +543,8 @@ def non_surjective_inverse(self, shape: list[Range | PrimExpr]) -> tuple["IndexM The region over which the inverse should be determined. Used for determining the predicate. + analyzer : Optional[tvm.arith.Analyzer] + The analyzer to use while deriving the inverse and padding predicate. Returns ------- @@ -555,4 +565,4 @@ def non_surjective_inverse(self, shape: list[Range | PrimExpr]) -> tuple["IndexM """ shape = [dim if isinstance(dim, Range) else Range(0, dim) for dim in shape] - return _ffi_api.IndexMapNonSurjectiveInverse(self, shape) + return _ffi_api.IndexMapNonSurjectiveInverse(self, shape, analyzer) diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 38c699692e7f..45f352c63131 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -34,14 +34,14 @@ namespace tvm { namespace arith { -Analyzer::Analyzer() +AnalyzerObj::AnalyzerObj() : const_int_bound(this), modular_set(this), rewrite_simplify(this), canonical_simplify(this), int_set(this) {} -void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { +void AnalyzerObj::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { PrimExpr new_expr = expr; new_expr = this->canonical_simplify(new_expr); new_expr = this->rewrite_simplify(new_expr); @@ -54,7 +54,7 @@ void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { this->transitive_comparisons.Bind(var, expr, allow_override); } -void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) { +void AnalyzerObj::Bind(const Var& var, const Range& range, bool allow_override) { TVM_FFI_ICHECK(range.defined()); if (tirx::is_one(range->extent)) { this->Bind(var, range->min, allow_override); @@ -67,7 +67,7 @@ void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) { // skip rewrite simplify } -void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) { +void AnalyzerObj::MarkGlobalNonNegValue(const PrimExpr& value) { // decompose value as symbol * scale + offset int64_t offset = 0; PrimExpr symbol_scale = tirx::make_const(value.dtype(), 0); @@ -117,7 +117,7 @@ void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) { } } -void Analyzer::Bind(const ffi::Map& variables, bool allow_override) { +void AnalyzerObj::Bind(const ffi::Map& variables, bool allow_override) { for (const auto& iter : variables) { this->Bind(iter.first, iter.second, allow_override); } @@ -143,7 +143,7 @@ void ConstraintContext::ExitWithScope() { } } -bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) { +bool AnalyzerObj::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) { if (const auto* ptr = expr.as()) { return ptr->value >= lower_bound; } @@ -152,7 +152,7 @@ bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) { return false; } -bool Analyzer::CanProveLess(const PrimExpr& expr, int64_t upper_bound) { +bool AnalyzerObj::CanProveLess(const PrimExpr& expr, int64_t upper_bound) { if (const auto* ptr = expr.as()) { return ptr->value < upper_bound; } @@ -161,7 +161,7 @@ bool Analyzer::CanProveLess(const PrimExpr& expr, int64_t upper_bound) { return false; } -bool Analyzer::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) { +bool AnalyzerObj::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) { const auto* clhs = lhs.as(); const auto* crhs = rhs.as(); if (clhs && crhs) return clhs->value == crhs->value; @@ -171,7 +171,8 @@ bool Analyzer::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) { return CanProve(lhs - rhs == 0); } -bool Analyzer::CanProveLessEqualThanSymbolicShapeValue(const PrimExpr& lhs, const PrimExpr& shape) { +bool AnalyzerObj::CanProveLessEqualThanSymbolicShapeValue(const PrimExpr& lhs, + const PrimExpr& shape) { if (this->CanProve(lhs <= shape, ProofStrength::kSymbolicBound)) return true; // no need to do further attempt if shape is already a constant. if (tirx::is_const_int(shape)) return false; @@ -189,7 +190,7 @@ bool Analyzer::CanProveLessEqualThanSymbolicShapeValue(const PrimExpr& lhs, cons return false; } -bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { +bool AnalyzerObj::CanProve(const PrimExpr& expr, ProofStrength strength) { // Avoid potentially expensive simplification unless required. if (const auto* ptr = expr.as()) { return ptr->value != 0; @@ -233,7 +234,7 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { return false; } -PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { +PrimExpr AnalyzerObj::Simplify(const PrimExpr& expr, int steps) { PrimExpr res = expr; // Always starts with a canonical simplification, as some structural property @@ -256,100 +257,69 @@ PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_packed("arith.CreateAnalyzer", [](ffi::PackedArgs args, ffi::Any* ret) { - using ffi::Function; - using ffi::TypedFunction; - auto self = std::make_shared(); - auto f = [self](std::string name) -> ffi::Function { - if (name == "const_int_bound") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->const_int_bound(args[0].cast()); - }); - } else if (name == "modular_set") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->modular_set(args[0].cast()); - }); - } else if (name == "const_int_bound_update") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - self->const_int_bound.Update(args[0].cast(), args[1].cast(), - args[2].cast()); - }); - } else if (name == "const_int_bound_is_bound") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->const_int_bound.IsBound(args[0].cast()); - }); - } else if (name == "Simplify") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - if (args.size() == 1) { - *ret = self->Simplify(args[0].cast()); - } else if (args.size() == 2) { - *ret = self->Simplify(args[0].cast(), args[1].cast()); - } else { - TVM_FFI_THROW(InternalError) << "Invalid size of argument (" << args.size() << ")"; - } - }); - } else if (name == "rewrite_simplify") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->rewrite_simplify(args[0].cast()); - }); - } else if (name == "get_rewrite_simplify_stats") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->rewrite_simplify.GetStatsCounters(); - }); - } else if (name == "reset_rewrite_simplify_stats") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - self->rewrite_simplify.ResetStatsCounters(); - }); - } else if (name == "canonical_simplify") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->canonical_simplify(args[0].cast()); - }); - } else if (name == "int_set") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->int_set(args[0].cast(), args[1].cast>()); - }); - } else if (name == "bind") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - bool allow_override = args.size() >= 3 && args[2].cast(); - if (auto opt_range = args[1].try_cast()) { - self->Bind(args[0].cast(), opt_range.value(), allow_override); - } else { - self->Bind(args[0].cast(), args[1].cast(), allow_override); - } - }); - } else if (name == "can_prove") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - int strength = args[1].cast(); - *ret = self->CanProve(args[0].cast(), static_cast(strength)); - }); - } else if (name == "enter_constraint_context") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - // can't use make_shared due to noexcept(false) decl in destructor, - // see https://stackoverflow.com/a/43907314 - auto ctx = std::shared_ptr>( - new With(self.get(), args[0].cast())); - auto fexit = [ctx](ffi::PackedArgs, ffi::Any*) mutable { ctx.reset(); }; - *ret = ffi::Function::FromPacked(fexit); - }); - } else if (name == "can_prove_equal") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = self->CanProveEqual(args[0].cast(), args[1].cast()); - }); - } else if (name == "get_enabled_extensions") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - *ret = static_cast(self->rewrite_simplify.GetEnabledExtensions()); - }); - } else if (name == "set_enabled_extensions") { - return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { - int64_t flags = args[0].cast(); - self->rewrite_simplify.SetEnabledExtensions( - static_cast(flags)); - }); - } - return ffi::Function(); - }; - *ret = ffi::TypedFunction(f); - }); + refl::ObjectDef(); + refl::GlobalDef() + .def("arith.Analyzer", []() { return Analyzer(); }) + .def("arith.AnalyzerConstIntBound", + [](Analyzer analyzer, const PrimExpr& expr) { return analyzer->const_int_bound(expr); }) + .def("arith.AnalyzerConstIntBoundUpdate", + [](Analyzer analyzer, const Var& var, const ConstIntBound& info, bool allow_override) { + analyzer->const_int_bound.Update(var, info, allow_override); + }) + .def("arith.AnalyzerConstIntBoundIsBound", + [](Analyzer analyzer, const Var& var) { return analyzer->const_int_bound.IsBound(var); }) + .def("arith.AnalyzerModularSet", + [](Analyzer analyzer, const PrimExpr& expr) { return analyzer->modular_set(expr); }) + .def("arith.AnalyzerSimplify", [](Analyzer analyzer, const PrimExpr& expr, + int steps) { return analyzer->Simplify(expr, steps); }) + .def("arith.AnalyzerRewriteSimplify", + [](Analyzer analyzer, const PrimExpr& expr) { return analyzer->rewrite_simplify(expr); }) + .def("arith.AnalyzerGetRewriteSimplifyStats", + [](Analyzer analyzer) { return analyzer->rewrite_simplify.GetStatsCounters(); }) + .def("arith.AnalyzerResetRewriteSimplifyStats", + [](Analyzer analyzer) { analyzer->rewrite_simplify.ResetStatsCounters(); }) + .def("arith.AnalyzerCanonicalSimplify", + [](Analyzer analyzer, const PrimExpr& expr) { + return analyzer->canonical_simplify(expr); + }) + .def("arith.AnalyzerIntSet", + [](Analyzer analyzer, const PrimExpr& expr, const ffi::Map& dom_map) { + return analyzer->int_set(expr, dom_map); + }) + .def_packed("arith.AnalyzerBind", + [](ffi::PackedArgs args, ffi::Any* ret) { + TVM_FFI_ICHECK(args.size() == 3 || args.size() == 4) + << "AnalyzerBind expects 3 or 4 arguments, but got " << args.size(); + Analyzer analyzer = args[0].cast(); + bool allow_override = args.size() >= 4 && args[3].cast(); + if (auto opt_range = args[2].try_cast()) { + analyzer->Bind(args[1].cast(), opt_range.value(), allow_override); + } else { + analyzer->Bind(args[1].cast(), args[2].cast(), allow_override); + } + }) + .def("arith.AnalyzerCanProve", + [](Analyzer analyzer, const PrimExpr& expr, int strength) { + return analyzer->CanProve(expr, static_cast(strength)); + }) + .def("arith.AnalyzerEnterConstraintContext", + [](Analyzer analyzer, const PrimExpr& constraint) { + // can't use make_shared due to noexcept(false) decl in destructor, + // see https://stackoverflow.com/a/43907314 + auto ctx = std::shared_ptr>( + new With(analyzer, constraint)); + auto fexit = [ctx](ffi::PackedArgs, ffi::Any*) mutable { ctx.reset(); }; + return ffi::Function::FromPacked(fexit); + }) + .def_method("arith.AnalyzerCanProveEqual", &AnalyzerObj::CanProveEqual) + .def("arith.AnalyzerGetEnabledExtensions", + [](Analyzer analyzer) { + return static_cast(analyzer->rewrite_simplify.GetEnabledExtensions()); + }) + .def("arith.AnalyzerSetEnabledExtensions", [](Analyzer analyzer, int64_t flags) { + analyzer->rewrite_simplify.SetEnabledExtensions( + static_cast(flags)); + }); } } // namespace arith diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc index 09f6d31ffd20..475a687cd462 100644 --- a/src/arith/bound_deducer.cc +++ b/src/arith/bound_deducer.cc @@ -137,7 +137,7 @@ class BoundDeducer : public ExprFunctor { } // always use relax bound - bool divided = analyzer_.CanProve(floormod(result_, operand) == 0); + bool divided = analyzer_->CanProve(floormod(result_, operand) == 0); result_ = floordiv(result_, operand); // rounding down here @@ -171,7 +171,7 @@ class BoundDeducer : public ExprFunctor { return; } PrimExpr divisor = op->b; - if (analyzer_.CanProveEqual(divisor, 0)) { + if (analyzer_->CanProveEqual(divisor, 0)) { // Skip zero divisor success_ = false; return; @@ -347,7 +347,7 @@ void BoundDeducer::Deduce() { this->VisitExpr(expr_); if (success_) { - result_ = analyzer_.Simplify(result_); + result_ = analyzer_->Simplify(result_); } } @@ -362,7 +362,7 @@ void BoundDeducer::Relax() { // can not be resolved when either `i` or `j` or both are variables with // some Range OR `i` and `j` both should be a single point in IntSet if (comp_op == kEqual && - (!analyzer_.CanProve(b.min() == b.max()) || !analyzer_.CanProve(a.min() == a.max()))) { + (!analyzer_->CanProve(b.min() == b.max()) || !analyzer_->CanProve(a.min() == a.max()))) { success_ = false; return; } diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 0001afbdfec2..a6093a25ba6a 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -83,7 +83,7 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) { * \param analyzer The analyzer * \return whether value fits in dtype */ -bool CastIsSafe(DataType dtype, PrimExpr value, Analyzer* analyzer) { +bool CastIsSafe(DataType dtype, PrimExpr value, AnalyzerObj* analyzer) { if (!IsIndexType(dtype)) { return false; } @@ -156,7 +156,7 @@ class SplitExprNode : public CanonicalExprNode { * \param analyzer The analyzer * \return whether the cast can be safely pushed to children */ - bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const { + bool CanPushCastToChildren(DataType dtype, AnalyzerObj* analyzer) const { // cast(dtype, index % upper_factor / lower_factor * scale) == // cast(dtype, index) % upper_factor / lower_factor * scale // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of @@ -334,7 +334,7 @@ class SumExprNode : public CanonicalExprNode { * \param analyzer The analyzer * \return whether the cast can be safely pushed to children */ - bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const { + bool CanPushCastToChildren(DataType dtype, AnalyzerObj* analyzer) const { bool is_min_value = dtype.bits() == 64 ? base == std::numeric_limits::lowest() : base == -(1LL << (dtype.bits() - 1)); // cast(dtype, arg_1 + arg_2 + ... arg_n) == @@ -545,7 +545,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { public: using Rewriter = RewriteSimplifier::Impl; - explicit Impl(Analyzer* parent) : Rewriter(parent) {} + explicit Impl(AnalyzerObj* parent) : Rewriter(parent) {} PrimExpr CanonicalSimplify(PrimExpr expr) { expr = operator()(expr); @@ -1450,7 +1450,7 @@ void CanonicalSimplifier::Update(const Var& var, const PrimExpr& info, bool over impl_->Update(var, info, override); } -CanonicalSimplifier::CanonicalSimplifier(Analyzer* parent) : impl_(new Impl(parent)) {} +CanonicalSimplifier::CanonicalSimplifier(AnalyzerObj* parent) : impl_(new Impl(parent)) {} CanonicalSimplifier::~CanonicalSimplifier() { delete impl_; } diff --git a/src/arith/conjunctive_normal_form.cc b/src/arith/conjunctive_normal_form.cc index 6aaef8327003..d88d9fd34df4 100644 --- a/src/arith/conjunctive_normal_form.cc +++ b/src/arith/conjunctive_normal_form.cc @@ -55,7 +55,7 @@ class AndOfOrs { PrimExpr AsPrimExpr() const; /*! \brief Simplify the internal representation */ - void Simplify(Analyzer* analyzer); + void Simplify(AnalyzerObj* analyzer); private: /*! \brief Internal utility, simplify within each group of expressions @@ -67,7 +67,7 @@ class AndOfOrs { * before = (a == 5) && ((b < 10) || (b > 10)) * after = (a == 5) && ((b != 10) || false) */ - void SimplifyWithinChunks(Analyzer* analyzer); + void SimplifyWithinChunks(AnalyzerObj* analyzer); /*! \brief Internal utility, simplify across groups of expressions * @@ -78,7 +78,7 @@ class AndOfOrs { * before = ((a == 5) || (b <= 10)) && ((a == 5) || (b >= 10)) * after = ((a == 5) || (b == 10)) && ((a == 5) || true) */ - void SimplifyAcrossChunks(Analyzer* analyzer); + void SimplifyAcrossChunks(AnalyzerObj* analyzer); /*! \brief Remove instances of true/false from internal representation * @@ -118,14 +118,14 @@ class AndOfOrs { * If successful, will overwrite the parameters `a` and `b` with the * simplified form. */ - void TrySimplifyOr(Key* a, Key* b, Analyzer* analyzer); + void TrySimplifyOr(Key* a, Key* b, AnalyzerObj* analyzer); /*! \brief Attempt to simplify (a || b) * * If successful, will overwrite the parameters `a` and `b` with the * simplified form. */ - void TrySimplifyAnd(Key* a, Key* b, Analyzer* analyzer); + void TrySimplifyAnd(Key* a, Key* b, AnalyzerObj* analyzer); /*! \brief The internal representation * @@ -246,7 +246,7 @@ PrimExpr AndOfOrs::AsPrimExpr() const { return expr; } -void AndOfOrs::TrySimplifyOr(Key* a_ptr, Key* b_ptr, Analyzer* analyzer) { +void AndOfOrs::TrySimplifyOr(Key* a_ptr, Key* b_ptr, AnalyzerObj* analyzer) { Key& a = *a_ptr; Key& b = *b_ptr; PrimExpr joint = GetExpr(a) || GetExpr(b); @@ -262,7 +262,7 @@ void AndOfOrs::TrySimplifyOr(Key* a_ptr, Key* b_ptr, Analyzer* analyzer) { } } -void AndOfOrs::TrySimplifyAnd(Key* a_ptr, Key* b_ptr, Analyzer* analyzer) { +void AndOfOrs::TrySimplifyAnd(Key* a_ptr, Key* b_ptr, AnalyzerObj* analyzer) { Key& a = *a_ptr; Key& b = *b_ptr; PrimExpr joint = GetExpr(a) && GetExpr(b); @@ -278,14 +278,14 @@ void AndOfOrs::TrySimplifyAnd(Key* a_ptr, Key* b_ptr, Analyzer* analyzer) { } } -void AndOfOrs::Simplify(Analyzer* analyzer) { +void AndOfOrs::Simplify(AnalyzerObj* analyzer) { SimplifyWithinChunks(analyzer); RemoveTrueFalse(); SimplifyAcrossChunks(analyzer); RemoveTrueFalse(); } -void AndOfOrs::SimplifyWithinChunks(Analyzer* analyzer) { +void AndOfOrs::SimplifyWithinChunks(AnalyzerObj* analyzer) { for (auto& chunk : chunks_) { for (size_t expr_i = 0; expr_i < chunk.size(); expr_i++) { for (size_t expr_j = expr_i + 1; expr_j < chunk.size(); expr_j++) { @@ -298,7 +298,7 @@ void AndOfOrs::SimplifyWithinChunks(Analyzer* analyzer) { } } -void AndOfOrs::SimplifyAcrossChunks(Analyzer* analyzer) { +void AndOfOrs::SimplifyAcrossChunks(AnalyzerObj* analyzer) { for (size_t i_and = 0; i_and < chunks_.size(); i_and++) { for (size_t j_and = i_and + 1; j_and < chunks_.size(); j_and++) { auto& i_chunk = chunks_[i_and]; @@ -417,7 +417,7 @@ void AndOfOrs::RemoveTrueFalse() { // recursion. class DisableAndOfOrRecursion { public: - explicit DisableAndOfOrRecursion(Analyzer* analyzer) + explicit DisableAndOfOrRecursion(AnalyzerObj* analyzer) : analyzer_(analyzer), cached_flags_(analyzer->rewrite_simplify.GetEnabledExtensions()) { auto new_flags = static_cast( cached_flags_ & (~RewriteSimplifier::kConvertBooleanToAndOfOrs)); @@ -429,13 +429,13 @@ class DisableAndOfOrRecursion { DisableAndOfOrRecursion& operator=(const DisableAndOfOrRecursion&) = delete; private: - Analyzer* analyzer_; + AnalyzerObj* analyzer_; RewriteSimplifier::Extension cached_flags_; }; } // namespace -PrimExpr SimplifyAsAndOfOrs(const PrimExpr& expr, Analyzer* analyzer) { +PrimExpr SimplifyAsAndOfOrs(const PrimExpr& expr, AnalyzerObj* analyzer) { DisableAndOfOrRecursion context(analyzer); AndOfOrs repr(analyzer->Simplify(expr)); repr.Simplify(analyzer); diff --git a/src/arith/conjunctive_normal_form.h b/src/arith/conjunctive_normal_form.h index a173ca587cdb..ad0c7dc4736c 100644 --- a/src/arith/conjunctive_normal_form.h +++ b/src/arith/conjunctive_normal_form.h @@ -31,6 +31,7 @@ namespace tvm { namespace arith { +class AnalyzerObj; class Analyzer; /*! \brief Convert boolean expression to AND of ORs and simplify @@ -41,7 +42,7 @@ class Analyzer; * * \return The simplified expression */ -PrimExpr SimplifyAsAndOfOrs(const PrimExpr& expr, Analyzer* analyzer); +PrimExpr SimplifyAsAndOfOrs(const PrimExpr& expr, AnalyzerObj* analyzer); } // namespace arith } // namespace tvm diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index c1dc4826f799..fa2d7254ee38 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -94,7 +94,7 @@ struct ConstIntBoundAnalyzer::Entry { class ConstIntBoundAnalyzer::Impl : public ExprFunctor { public: - explicit Impl(Analyzer* parent) : parent_(parent) {} + explicit Impl(AnalyzerObj* parent) : parent_(parent) {} /*! \brief additional bound info about expr in bound */ struct BoundInfo { /*! \brief The expr */ @@ -500,7 +500,7 @@ class ConstIntBoundAnalyzer::Impl private: friend class ConstIntBoundAnalyzer; // parent analyzer - Analyzer* parent_; + AnalyzerObj* parent_; // internal variable map std::unordered_map var_map_; // additional bound info @@ -852,7 +852,7 @@ std::function ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& con return impl_->EnterConstraint(constraint); } -ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {} +ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(AnalyzerObj* parent) : impl_(new Impl(parent)) {} ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { delete impl_; } diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index bb74eae7cb92..d7a4874de0b3 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -215,7 +215,7 @@ bool DetectClipBound(const PrimExpr& cond, LinearEqEntry ret; Analyzer analyzer; if (!LinearEqDetector(var).Detect(canonical, &ret)) return false; - ret.coeff = analyzer.Simplify(ret.coeff); + ret.coeff = analyzer->Simplify(ret.coeff); IntervalEntry& p = (*bmap)[var.get()]; ffi::Optional min_value; @@ -268,7 +268,7 @@ void SplitCommExpr(const PrimExpr& e, std::vector* ret) { ffi::Array DetectClipBound(const PrimExpr& e, const ffi::Array& vars) { std::vector splits; Analyzer analyzer; - SplitCommExpr(analyzer.Simplify(e), &splits); + SplitCommExpr(analyzer->Simplify(e), &splits); std::unordered_map rmap; for (Var v : vars) { rmap[v.get()] = IntervalEntry(); @@ -280,10 +280,10 @@ ffi::Array DetectClipBound(const PrimExpr& e, const ffi::Array& v for (Var v : vars) { IntervalEntry e = rmap[v.get()]; if (e.min_value.defined()) { - e.min_value = analyzer.Simplify(e.min_value); + e.min_value = analyzer->Simplify(e.min_value); } if (e.max_value.defined()) { - e.max_value = analyzer.Simplify(e.max_value); + e.max_value = analyzer->Simplify(e.max_value); } ret.push_back(e.min_value); ret.push_back(e.max_value); diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index 977ea779f450..6701beee3d69 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -125,7 +125,7 @@ class BufferTouchedDomain final : public IRVisitorWithAnalyzer { if (args[i].as()) { (*bounds)[i].emplace_back(IntSet::Vector(args[i])); } else { - (*bounds)[i].emplace_back(analyzer_.int_set(args[i])); + (*bounds)[i].emplace_back(analyzer_->int_set(args[i])); } } } diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index a6b26d16cdda..8a24d262e4fc 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -94,7 +94,7 @@ IntGroupBounds IntGroupBounds::FromRange(const Range& r) { equal.push_back(r->min); } else { lower.push_back(r->min); - upper.push_back(analyzer.Simplify(r->min + r->extent - 1)); + upper.push_back(analyzer->Simplify(r->min + r->extent - 1)); } return IntGroupBounds(coef, lower, equal, upper); } @@ -106,10 +106,10 @@ IntGroupBounds IntGroupBounds::operator+(const Range& r) { ffi::Array upper; const PrimExpr& coef = operator->()->coef; if (tirx::is_one(r->extent)) { - equal.push_back(analyzer.Simplify(r->min * coef)); + equal.push_back(analyzer->Simplify(r->min * coef)); } else { - lower.push_back(analyzer.Simplify(r->min * coef)); - upper.push_back(analyzer.Simplify((r->min + r->extent - 1) * coef)); + lower.push_back(analyzer->Simplify(r->min * coef)); + upper.push_back(analyzer->Simplify((r->min + r->extent - 1) * coef)); } for (const auto& eq : operator->()->equal) equal.push_back(eq); for (const auto& lb : operator->()->lower) lower.push_back(lb); @@ -127,7 +127,7 @@ IntGroupBounds IntGroupBounds::Substitute(const ffi::Map& subst) Range IntGroupBounds::FindBestRange(const ffi::Map& vranges_addl) const { Analyzer analyzer; - analyzer.Bind(vranges_addl); + analyzer->Bind(vranges_addl); std::unordered_map var_intsets; for (auto kv : vranges_addl) { @@ -147,7 +147,7 @@ Range IntGroupBounds::FindBestRange(const ffi::Map& vranges_addl) co } if (lowers.size() == 1 && uppers.size() == 1 && tirx::is_one(coef)) { - return Range(analyzer.Simplify(lowers[0]), analyzer.Simplify(uppers[0] + 1)); + return Range(analyzer->Simplify(lowers[0]), analyzer->Simplify(uppers[0] + 1)); } // Here we will try all pairs of lower and upper bounds and find the best pair, that is, the @@ -163,22 +163,22 @@ Range IntGroupBounds::FindBestRange(const ffi::Map& vranges_addl) co for (const PrimExpr& upp : uppers) { // Since diff may depend on some other variables, we compute its overapproximation ffi::Optional diff_over; - PrimExpr diff_1 = analyzer.Simplify(floordiv(upp - low, coef), 3); + PrimExpr diff_1 = analyzer->Simplify(floordiv(upp - low, coef), 3); IntSet diff_set1 = EvalSet(diff_1, var_intsets); if (diff_set1.HasUpperBound()) { - diff_over = analyzer.Simplify(diff_set1.max(), 3); + diff_over = analyzer->Simplify(diff_set1.max(), 3); } // low is the lower bound for v*coef, but we need the lower bound for v. // We use rounding-up division to compute it. Since we want to use a single formula - PrimExpr low_divided = analyzer.Simplify(floordiv(low + coef - 1, coef), 3); + PrimExpr low_divided = analyzer->Simplify(floordiv(low + coef - 1, coef), 3); // Compute another difference which may be more precise (or not). - PrimExpr diff_2 = analyzer.Simplify(floordiv(upp, coef) - low_divided, 3); + PrimExpr diff_2 = analyzer->Simplify(floordiv(upp, coef) - low_divided, 3); IntSet diff_set2 = EvalSet(diff_2, var_intsets); if (diff_set2.HasUpperBound()) { - PrimExpr diff_over_2 = analyzer.Simplify(diff_set2.max(), 3); - diff_over = diff_over.defined() ? (analyzer.CanProve(diff_over_2 - diff_over.value() < 0) + PrimExpr diff_over_2 = analyzer->Simplify(diff_set2.max(), 3); + diff_over = diff_over.defined() ? (analyzer->CanProve(diff_over_2 - diff_over.value() < 0) ? diff_over_2 : diff_over.value()) : diff_over_2; @@ -187,7 +187,7 @@ Range IntGroupBounds::FindBestRange(const ffi::Map& vranges_addl) co // If it is provable that the new one is strictly better than the current best one, // then replace it. Note that we are biased towards earlier pairs which should be simpler. if (diff_over.defined() && (!best_diff_over.defined() || - analyzer.CanProve(diff_over.value() - best_diff_over < 0))) { + analyzer->CanProve(diff_over.value() - best_diff_over < 0))) { best_lower = low_divided; best_diff_over = diff_over.value(); } @@ -198,7 +198,7 @@ Range IntGroupBounds::FindBestRange(const ffi::Map& vranges_addl) co TVM_FFI_ICHECK(!best_diff_over.defined()); return Range(); } - return Range::FromMinExtent(best_lower, analyzer.Simplify(best_diff_over + 1)); + return Range::FromMinExtent(best_lower, analyzer->Simplify(best_diff_over + 1)); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -271,15 +271,15 @@ IntConstraintsTransform IntConstraintsTransform::operator+( ffi::Map src_to_dst; Analyzer ana_first; - ana_first.Bind(operator->()->src->ranges); + ana_first->Bind(operator->()->src->ranges); for (auto p : other->dst_to_src) { - dst_to_src.Set(p.first, ana_first.Simplify(Substitute(p.second, operator->()->dst_to_src))); + dst_to_src.Set(p.first, ana_first->Simplify(Substitute(p.second, operator->()->dst_to_src))); } Analyzer ana_second; - ana_second.Bind(other->dst->ranges); + ana_second->Bind(other->dst->ranges); for (auto p : operator->()->src_to_dst) { - src_to_dst.Set(p.first, ana_second.Simplify(Substitute(p.second, other->src_to_dst))); + src_to_dst.Set(p.first, ana_second->Simplify(Substitute(p.second, other->src_to_dst))); } return IntConstraintsTransform(operator->()->src, other->dst, src_to_dst, dst_to_src); } diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 86a2d949bc12..d16a6bc7b58d 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -70,7 +70,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("arith.IntervalSet", MakeIntervalSet); } -IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) { +IntervalSet Intersect(AnalyzerObj* analyzer, IntervalSet a, IntervalSet b) { PrimExpr max_value = min(a->max_value, b->max_value); PrimExpr min_value = max(a->min_value, b->min_value); if ((max_value.dtype().is_int() || max_value.dtype().is_uint()) && @@ -82,7 +82,7 @@ IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) { } } -IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b) { +IntervalSet Union(AnalyzerObj* analyzer, IntervalSet a, IntervalSet b) { if (a->IsEmpty()) return b; if (b->IsEmpty()) return a; PrimExpr max_value = max(a->max_value, b->max_value); @@ -121,7 +121,7 @@ TVM_DECLARE_LOGICAL_OP(Not); * \note this can possibly relax the set. */ template -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, const OpNode* op) { +inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, IntervalSet b, const OpNode* op) { DataType dtype = op->dtype; if (a->IsSinglePoint() && b->IsSinglePoint()) { PrimExpr expr; @@ -143,7 +143,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, con } template <> -inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b, +inline IntervalSet Combine(AnalyzerObj* analyer, IntervalSet a, IntervalSet b, const tirx::AddNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value + b->min_value); @@ -158,7 +158,7 @@ inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, Interval } template <> -inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b, +inline IntervalSet Combine(AnalyzerObj* analyer, IntervalSet a, IntervalSet b, const tirx::SubNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value - b->min_value); @@ -173,7 +173,7 @@ inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, Interval } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, +inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, IntervalSet b, const tirx::MulNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value * b->min_value); @@ -207,7 +207,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interva } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, +inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, IntervalSet b, const tirx::DivNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value / b->min_value); @@ -241,7 +241,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interva } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, +inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, IntervalSet b, const tirx::ModNode* op) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value)); @@ -270,7 +270,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interva } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, +inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, IntervalSet b, const tirx::FloorDivNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value)); @@ -304,7 +304,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, In } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, +inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, IntervalSet b, const tirx::FloorModNode* op) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value)); @@ -365,7 +365,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, In } template <> -inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b, +inline IntervalSet Combine(AnalyzerObj* analzyer, IntervalSet a, IntervalSet b, const tirx::MaxNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(max(a->min_value, b->min_value)); @@ -376,7 +376,7 @@ inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, Interva } template <> -inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b, +inline IntervalSet Combine(AnalyzerObj* analzyer, IntervalSet a, IntervalSet b, const tirx::MinNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(min(a->min_value, b->min_value)); @@ -401,7 +401,7 @@ using namespace tirx; // We might use better set analysis in the future to replace the intervalset. class IntervalSetEvaluator : public ExprFunctor { public: - IntervalSetEvaluator(Analyzer* analyzer, const ffi::Map& dom_map, + IntervalSetEvaluator(AnalyzerObj* analyzer, const ffi::Map& dom_map, const std::vector>* dom_constraints = nullptr, bool eval_vec = false) : analyzer_(analyzer), @@ -633,7 +633,7 @@ class IntervalSetEvaluator : public ExprFunctor { // Variables currently being relaxed, used to break cyclic dependencies. std::unordered_set relax_in_progress_; // analyzer - Analyzer* analyzer_; + AnalyzerObj* analyzer_; const ffi::Map& dom_map_; const std::vector>* dom_constraints_; bool eval_vec_{false}; @@ -641,7 +641,7 @@ class IntervalSetEvaluator : public ExprFunctor { class IntSetAnalyzer::Impl { public: - explicit Impl(Analyzer* analyzer) : analyzer_(analyzer) {} + explicit Impl(AnalyzerObj* analyzer) : analyzer_(analyzer) {} IntSet Eval(const PrimExpr& expr, const ffi::Map& dom_map) const { return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr); @@ -665,7 +665,7 @@ class IntSetAnalyzer::Impl { static std::vector> DetectBoundInfo(const PrimExpr& cond); // The parent arith::Analyzer - Analyzer* analyzer_; + AnalyzerObj* analyzer_; // Map of variables to global variable bounds (e.g. loop iterator // ranges) @@ -678,7 +678,7 @@ class IntSetAnalyzer::Impl { std::vector> dom_constraints_; }; -IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {} +IntSetAnalyzer::IntSetAnalyzer(AnalyzerObj* parent) : impl_(new Impl(parent)) {} IntSetAnalyzer::~IntSetAnalyzer() { delete impl_; } @@ -781,8 +781,8 @@ Range IntSet::CoverRange(Range max_range) const { const IntervalSetNode* s_int = (*this).as(); TVM_FFI_ICHECK(s_int != nullptr); if (s_int->HasUpperBound() && s_int->HasLowerBound()) { - return Range::FromMinExtent(analyzer.Simplify(s_int->min_value), - analyzer.Simplify(s_int->max_value + 1 - s_int->min_value)); + return Range::FromMinExtent(analyzer->Simplify(s_int->min_value), + analyzer->Simplify(s_int->max_value + 1 - s_int->min_value)); } return max_range; } @@ -814,7 +814,7 @@ bool IntSet::IsSinglePoint() const { return (s_int && s_int->IsSinglePoint()); } -bool IntSet::CanProveSinglePoint(Analyzer* ana) const { +bool IntSet::CanProveSinglePoint(const Analyzer& ana) const { const IntervalSetNode* s_int = (*this).as(); if (!s_int) return false; if (s_int->IsSinglePoint()) return true; @@ -824,19 +824,19 @@ bool IntSet::CanProveSinglePoint(Analyzer* ana) const { bool IntSet::CanProvePositive() const { Analyzer analyzer; const IntervalSetNode* s_int = (*this).as(); - return (s_int && is_positive_const(analyzer.Simplify(s_int->min_value))); + return (s_int && is_positive_const(analyzer->Simplify(s_int->min_value))); } bool IntSet::CanProveNegative() const { Analyzer analyzer; const IntervalSetNode* s_int = (*this).as(); - return (s_int && is_negative_const(analyzer.Simplify(s_int->max_value))); + return (s_int && is_negative_const(analyzer->Simplify(s_int->max_value))); } bool IntSet::CanProveNonPositive() const { Analyzer analyzer; if (const auto* s_int = (*this).as()) { - auto max = analyzer.Simplify(s_int->max_value); + auto max = analyzer->Simplify(s_int->max_value); return is_zero(max) || is_negative_const(max); } return false; @@ -845,7 +845,7 @@ bool IntSet::CanProveNonPositive() const { bool IntSet::CanProveNonNegative() const { Analyzer analyzer; if (const IntervalSetNode* s_int = (*this).as()) { - auto min = analyzer.Simplify(s_int->min_value); + auto min = analyzer->Simplify(s_int->min_value); return is_zero(min) || is_positive_const(min); } return false; @@ -896,7 +896,7 @@ IntSet IntSet::Interval(PrimExpr min, PrimExpr max) { } // Range related code -inline bool ProveEqual(Analyzer* analyzer, PrimExpr lhs, PrimExpr rhs) { +inline bool ProveEqual(AnalyzerObj* analyzer, PrimExpr lhs, PrimExpr rhs) { return is_zero(analyzer->Simplify(lhs - rhs)); } @@ -921,8 +921,8 @@ bool IntSet::MatchRange(const Range& b) const { if (!a_int) return false; if (!a_int->HasUpperBound() || !a_int->HasLowerBound()) return false; Analyzer ana; - return ProveEqual(&ana, a_int->min_value, b->min) && - ProveEqual(&ana, a_int->max_value, b->extent + b->min - 1); + return ProveEqual(ana.get(), a_int->min_value, b->min) && + ProveEqual(ana.get(), a_int->max_value, b->extent + b->min - 1); } IntSet Union(const ffi::Array& sets) { @@ -931,9 +931,9 @@ IntSet Union(const ffi::Array& sets) { Analyzer ana; IntervalSet x = ToIntervalSet(sets[0]); for (size_t i = 1; i < sets.size(); ++i) { - x = Union(&ana, x, ToIntervalSet(sets[i])); + x = Union(ana.get(), x, ToIntervalSet(sets[i])); } - return IntervalSet(ana.Simplify(x->min_value), ana.Simplify(x->max_value)); + return IntervalSet(ana->Simplify(x->min_value), ana->Simplify(x->max_value)); } ffi::Array UnionRegion(const ffi::Array>& nd_int_sets) { @@ -974,9 +974,9 @@ IntSet UnionLowerBound(const ffi::Array& sets) { continue; } bool bound_1 = is_neg_inf(new_min_inclusive) || is_pos_inf(max_inclusive) || - analyzer.CanProve(new_min_inclusive <= max_inclusive + 1); + analyzer->CanProve(new_min_inclusive <= max_inclusive + 1); bool bound_2 = is_neg_inf(min_inclusive) || is_pos_inf(new_max_inclusive) || - analyzer.CanProve(min_inclusive <= new_max_inclusive + 1); + analyzer->CanProve(min_inclusive <= new_max_inclusive + 1); if (bound_1 && bound_2) { min_inclusive = min(min_inclusive, new_min_inclusive); max_inclusive = max(max_inclusive, new_max_inclusive); @@ -1014,9 +1014,9 @@ IntSet Intersect(const ffi::Array& sets) { Analyzer ana; IntervalSet x = ToIntervalSet(sets[0]); for (size_t i = 1; i < sets.size(); ++i) { - x = Intersect(&ana, x, ToIntervalSet(sets[i])); + x = Intersect(ana.get(), x, ToIntervalSet(sets[i])); } - return IntervalSet(ana.Simplify(x->min_value), ana.Simplify(x->max_value)); + return IntervalSet(ana->Simplify(x->min_value), ana->Simplify(x->max_value)); } ffi::Map ConvertDomMap(const ffi::Map& dom_map) { @@ -1037,7 +1037,7 @@ ffi::Map ConvertDomMap(const std::unordered_map& dom_map) { Analyzer ana; - return IntervalSetEvaluator(&ana, dom_map, {}, false).Eval(e); + return IntervalSetEvaluator(ana.get(), dom_map, {}, false).Eval(e); } IntSet IntSet::Vector(PrimExpr x) { @@ -1048,7 +1048,7 @@ IntSet IntSet::Vector(PrimExpr x) { // vector case. Analyzer ana; ffi::Map dmap; - return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x); + return IntervalSetEvaluator(ana.get(), dmap, {}, true).Eval(x); } } @@ -1062,13 +1062,13 @@ IntSet EvalSet(PrimExpr e, const std::unordered_map& dom IntSet EvalSet(Range r, const ffi::Map& dom_map) { Analyzer ana; - if ((r->min->dtype.is_int() || r->min->dtype.is_uint()) && ana.CanProveEqual(r->extent, 1)) { + if ((r->min->dtype.is_int() || r->min->dtype.is_uint()) && ana->CanProveEqual(r->extent, 1)) { return EvalSet(r->min, dom_map); } - IntervalSetEvaluator m(&ana, dom_map); + IntervalSetEvaluator m(ana.get(), dom_map); // Simplifying first can give tighter bounds if r->min and r->extent share variables PrimExpr sum = r->min + r->extent - 1; - auto res = m.Eval(IntervalSet(r->min, ana.Simplify(sum))); + auto res = m.Eval(IntervalSet(r->min, ana->Simplify(sum))); return res; } @@ -1078,12 +1078,12 @@ IntSet EvalSet(Range r, const std::unordered_map& dom_ma ffi::Array EvalSet(const ffi::Array& region, const ffi::Map& dom_map) { Analyzer ana; - IntervalSetEvaluator m(&ana, dom_map); + IntervalSetEvaluator m(ana.get(), dom_map); ffi::Array result; result.reserve(region.size()); for (const Range& r : region) { PrimExpr sum = r->min + (r->extent - 1); - result.push_back(m.Eval(IntervalSet(r->min, ana.Simplify(sum)))); + result.push_back(m.Eval(IntervalSet(r->min, ana->Simplify(sum)))); } return result; } @@ -1091,7 +1091,7 @@ ffi::Array EvalSet(const ffi::Array& region, const ffi::Map& dom_map) { Analyzer ana; auto dmap = ConvertDomMap(dom_map); - IntervalSetEvaluator m(&ana, dmap); + IntervalSetEvaluator m(ana.get(), dmap); const IntervalSetNode* s_int = s.as(); PrimExpr vmax = s_int->HasUpperBound() ? m.Eval(s_int->max_value).max() : s_int->max_value; PrimExpr vmin = s_int->HasLowerBound() ? m.Eval(s_int->min_value).min() : s_int->min_value; @@ -1100,7 +1100,7 @@ IntSet EvalSet(IntSet s, const std::unordered_map& dom_m class SubExprIntervalSetEvaluator : public IntervalSetEvaluator { public: - explicit SubExprIntervalSetEvaluator(Analyzer* analyzer, const ffi::Map& dom_map) + explicit SubExprIntervalSetEvaluator(AnalyzerObj* analyzer, const ffi::Map& dom_map) : IntervalSetEvaluator(analyzer, dom_map) {} IntervalSet VisitExpr(const PrimExpr& n) final { @@ -1116,7 +1116,7 @@ ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e, const std::unordered_map& dom_map) { Analyzer ana; auto dmap = ConvertDomMap(dom_map); - SubExprIntervalSetEvaluator m(&ana, dmap); + SubExprIntervalSetEvaluator m(ana.get(), dmap); m.Eval(e); return m.expr_map; } @@ -1137,7 +1137,7 @@ ffi::Map AsIntSet(const ffi::Map& var_dom) { /*! \brief Helper function to convert IterSumExpr to the actual touched range. */ static ffi::Optional EvalIterSum(const IterSumExpr& iter_min, const PrimExpr& extent, - Analyzer* analyzer) { + AnalyzerObj* analyzer) { if (analyzer->CanProve(extent == 0)) { return IntSet::Nothing(); } @@ -1172,7 +1172,8 @@ static ffi::Optional EvalIterSum(const IterSumExpr& iter_min, const Prim ffi::Optional> EstimateRegionStrictBound(const ffi::Array& region, const ffi::Map& var_dom, const PrimExpr& predicate, - Analyzer* analyzer) { + const Analyzer& analyzer) { + AnalyzerObj* analyzer_ptr = analyzer.get(); int ndim = region.size(); ffi::Array iter_sum_exprs{nullptr}; { @@ -1199,7 +1200,7 @@ ffi::Optional> EstimateRegionStrictBound(const ffi::Array int_set = EvalIterSum(sum_expr, range->extent, analyzer); + ffi::Optional int_set = EvalIterSum(sum_expr, range->extent, analyzer_ptr); if (int_set.defined()) { result.push_back(int_set.value()); } else { @@ -1212,13 +1213,14 @@ ffi::Optional> EstimateRegionStrictBound(const ffi::Array> EstimateRegionLowerBound(const ffi::Array& region, const ffi::Map& var_dom, const PrimExpr& predicate, - arith::Analyzer* analyzer) { + const Analyzer& analyzer) { return EstimateRegionStrictBound(region, var_dom, predicate, analyzer); } ffi::Array EstimateRegionUpperBound(const ffi::Array& region, const ffi::Map& var_dom, - const PrimExpr& predicate, Analyzer* analyzer) { + const PrimExpr& predicate, const Analyzer& analyzer) { + AnalyzerObj* analyzer_ptr = analyzer.get(); if (ffi::Optional> result = EstimateRegionStrictBound( /*region=*/region, /*var_dom=*/var_dom, @@ -1244,7 +1246,7 @@ ffi::Array EstimateRegionUpperBound(const ffi::Array& region, extent = relaxed.max(); } - if (ffi::Optional int_set = EvalIterSum(sum_expr, range->extent, analyzer)) { + if (ffi::Optional int_set = EvalIterSum(sum_expr, range->extent, analyzer_ptr)) { result.push_back(int_set.value()); continue; } @@ -1271,19 +1273,19 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](ffi::Array region, ffi::Map var_dom, PrimExpr predicate) -> ffi::Optional> { Analyzer analyzer; - return EstimateRegionLowerBound(region, var_dom, predicate, &analyzer); + return EstimateRegionLowerBound(region, var_dom, predicate, analyzer); }) .def("arith.EstimateRegionStrictBound", [](ffi::Array region, ffi::Map var_dom, PrimExpr predicate) -> ffi::Optional> { Analyzer analyzer; - return EstimateRegionStrictBound(region, var_dom, predicate, &analyzer); + return EstimateRegionStrictBound(region, var_dom, predicate, analyzer); }) .def("arith.EstimateRegionUpperBound", [](ffi::Array region, ffi::Map var_dom, PrimExpr predicate) -> ffi::Optional> { Analyzer analyzer; - return EstimateRegionUpperBound(region, var_dom, predicate, &analyzer); + return EstimateRegionUpperBound(region, var_dom, predicate, analyzer); }) .def("arith.PosInf", []() { return SymbolicLimits::pos_inf_; }) .def("arith.NegInf", []() { return SymbolicLimits::neg_inf_; }) diff --git a/src/arith/interval_set.h b/src/arith/interval_set.h index c88239939623..72820c178e6b 100644 --- a/src/arith/interval_set.h +++ b/src/arith/interval_set.h @@ -122,7 +122,7 @@ class IntervalSet : public IntSet { * \param b The second set. * \return The result set. */ -TVM_DLL IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b); +TVM_DLL IntervalSet Union(AnalyzerObj* analyzer, IntervalSet a, IntervalSet b); /*! * \brief Create insersection of two IntervalSets. @@ -131,7 +131,7 @@ TVM_DLL IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b); * \param b The second set. * \return The result set. */ -TVM_DLL IntervalSet Intersect(Analyzer* analzyer, IntervalSet a, IntervalSet b); +TVM_DLL IntervalSet Intersect(AnalyzerObj* analzyer, IntervalSet a, IntervalSet b); } // namespace arith } // namespace tvm diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 39b7faad84fd..6fbd4c4551f3 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -135,13 +135,14 @@ void CollectDerivedConstraintFacts(const PrimExpr& condition, std::vector* constraints, Analyzer* analyzer, +void EnterConstraintFacts(WithGroup* constraints, AnalyzerObj* analyzer, const PrimExpr& condition) { - constraints->Emplace(analyzer, condition); + arith::Analyzer analyzer_ref = ffi::GetRef(analyzer); + constraints->Emplace(analyzer_ref, condition); std::vector derived; CollectDerivedConstraintFacts(condition, &derived); for (const PrimExpr& fact : derived) { - constraints->Emplace(analyzer, fact); + constraints->Emplace(analyzer_ref, fact); } } @@ -163,8 +164,9 @@ ffi::Array IRMutatorWithAnalyzer::IterMapSimplifyWithContext( pred = pred && val; } int n = indices.size(); + arith::Analyzer analyzer_ref = ffi::GetRef(this->analyzer_); ffi::Array simplified = arith::IterMapSimplify( - indices, this->iter_vars_, pred, arith::IterMapLevel::Surjective, this->analyzer_); + indices, this->iter_vars_, pred, arith::IterMapLevel::Surjective, analyzer_ref); if (non_trivial_only) { for (int i = 0; i < n; ++i) { if (simplified[i]->IsInstance() && indices[i]->IsInstance()) { diff --git a/src/arith/ir_mutator_with_analyzer.h b/src/arith/ir_mutator_with_analyzer.h index e15d121cfad9..5e2fa6ab0006 100644 --- a/src/arith/ir_mutator_with_analyzer.h +++ b/src/arith/ir_mutator_with_analyzer.h @@ -47,7 +47,7 @@ namespace arith { */ class IRMutatorWithAnalyzer : public tirx::StmtExprMutator { public: - explicit IRMutatorWithAnalyzer(Analyzer* analyzer) : analyzer_(analyzer) {} + explicit IRMutatorWithAnalyzer(AnalyzerObj* analyzer) : analyzer_(analyzer) {} using StmtExprMutator::VisitExpr_; using StmtExprMutator::VisitStmt_; @@ -82,7 +82,7 @@ class IRMutatorWithAnalyzer : public tirx::StmtExprMutator { bool non_trivial_only); /*! \brief internal analyzer field. */ - Analyzer* analyzer_; + AnalyzerObj* analyzer_; /*! \brief Scope stack for accumulated assert constraints. */ ScopeStack> constraint_scope_; // the following two fields are useful in case we want diff --git a/src/arith/ir_visitor_with_analyzer.cc b/src/arith/ir_visitor_with_analyzer.cc index 35f2bec52919..7c0f2458fa77 100644 --- a/src/arith/ir_visitor_with_analyzer.cc +++ b/src/arith/ir_visitor_with_analyzer.cc @@ -34,7 +34,7 @@ using namespace tirx; void IRVisitorWithAnalyzer::VisitStmt_(const ForNode* op) { constraint_scope_.WithNewScope([&]() { - analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); StmtExprVisitor::VisitStmt_(op); }); } @@ -42,7 +42,7 @@ void IRVisitorWithAnalyzer::VisitStmt_(const ForNode* op) { void IRVisitorWithAnalyzer::VisitStmt_(const SBlockNode* op) { constraint_scope_.WithNewScope([&]() { for (const auto& iter_var : op->iter_vars) { - analyzer_.Bind(iter_var->var, iter_var->dom); + analyzer_->Bind(iter_var->var, iter_var->dom); } StmtExprVisitor::VisitStmt_(op); }); @@ -50,7 +50,7 @@ void IRVisitorWithAnalyzer::VisitStmt_(const SBlockNode* op) { void IRVisitorWithAnalyzer::VisitStmt_(const BindNode* op) { this->VisitExpr(op->value); - analyzer_.Bind(op->var, op->value); + analyzer_->Bind(op->var, op->value); } void IRVisitorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { @@ -60,13 +60,13 @@ void IRVisitorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { PrimExpr real_condition = ExtractRealCondition(op->condition); constraint_scope_.WithNewScope([&]() { - constraint_scope_.Current().Emplace(&analyzer_, real_condition); + constraint_scope_.Current().Emplace(analyzer_, real_condition); this->VisitStmt(op->then_case); }); if (op->else_case) { constraint_scope_.WithNewScope([&]() { - constraint_scope_.Current().Emplace(&analyzer_, - analyzer_.rewrite_simplify(Not(real_condition))); + constraint_scope_.Current().Emplace(analyzer_, + analyzer_->rewrite_simplify(Not(real_condition))); this->VisitStmt(op->else_case.value()); }); } @@ -78,7 +78,7 @@ void IRVisitorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == tirx::attr::thread_extent || op->attr_key == s_tir::attr::virtual_thread) { IterVar iv = Downcast(op->node); TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U); - analyzer_.Bind(iv->var, Range::FromMinExtent(IntImm(op->value->dtype, 0), op->value)); + analyzer_->Bind(iv->var, Range::FromMinExtent(IntImm(op->value->dtype, 0), op->value)); } StmtExprVisitor::VisitStmt_(op); }); @@ -86,7 +86,7 @@ void IRVisitorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) { void IRVisitorWithAnalyzer::VisitStmt_(const AssertStmtNode* op) { this->VisitExpr(op->condition); - constraint_scope_.Current().Emplace(&analyzer_, op->condition); + constraint_scope_.Current().Emplace(analyzer_, op->condition); } void IRVisitorWithAnalyzer::VisitStmt_(const SeqStmtNode* op) { @@ -101,11 +101,11 @@ void IRVisitorWithAnalyzer::VisitExpr_(const CallNode* op) { PrimExpr cond = op->args[0]; this->VisitExpr(op->args[0]); constraint_scope_.WithNewScope([&]() { - constraint_scope_.Current().Emplace(&analyzer_, cond); + constraint_scope_.Current().Emplace(analyzer_, cond); this->VisitExpr(op->args[1]); }); constraint_scope_.WithNewScope([&]() { - constraint_scope_.Current().Emplace(&analyzer_, analyzer_.rewrite_simplify(Not(cond))); + constraint_scope_.Current().Emplace(analyzer_, analyzer_->rewrite_simplify(Not(cond))); this->VisitExpr(op->args[2]); }); } else { @@ -115,13 +115,13 @@ void IRVisitorWithAnalyzer::VisitExpr_(const CallNode* op) { void IRVisitorWithAnalyzer::VisitExpr_(const LetNode* op) { this->VisitExpr(op->value); - analyzer_.Bind(op->var, op->value); + analyzer_->Bind(op->var, op->value); this->VisitExpr(op->body); } void IRVisitorWithAnalyzer::VisitExpr_(const ReduceNode* op) { for (const IterVar& iv : op->axis) { - analyzer_.Bind(iv->var, iv->dom); + analyzer_->Bind(iv->var, iv->dom); } StmtExprVisitor::VisitExpr_(op); } diff --git a/src/arith/ir_visitor_with_analyzer.h b/src/arith/ir_visitor_with_analyzer.h index 24728c69e19c..55131d6a20c9 100644 --- a/src/arith/ir_visitor_with_analyzer.h +++ b/src/arith/ir_visitor_with_analyzer.h @@ -36,7 +36,7 @@ namespace arith { class IRVisitorWithAnalyzer : public tirx::StmtExprVisitor { public: - PrimExpr Simplify(const PrimExpr& expr) { return analyzer_.Simplify(expr); } + PrimExpr Simplify(const PrimExpr& expr) { return analyzer_->Simplify(expr); } using StmtExprVisitor::VisitExpr_; using StmtExprVisitor::VisitStmt_; diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 2f9111a0c03a..8623efa1a64d 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -174,7 +174,7 @@ class IterMapRewriter : public ExprMutator { public: using Parent = ExprMutator; - explicit IterMapRewriter(Analyzer* analyzer, const ffi::Map& input_iters, + explicit IterMapRewriter(AnalyzerObj* analyzer, const ffi::Map& input_iters, IterMapLevel check_level, bool simplify_trivial_iterators, ffi::Array* errors) : analyzer_(analyzer), @@ -431,7 +431,7 @@ class IterMapRewriter : public ExprMutator { }; // Internal analyzer - Analyzer* analyzer_; + AnalyzerObj* analyzer_; // Iter map check level IterMapLevel check_level_; // Error messages for each unresolved expression. @@ -1369,8 +1369,8 @@ bool MatchBoundConstraints(PrimExpr pred, ffi::Map* input_iters, }; f_extract(sum_parts, true); arith::Analyzer analyzer; - lhs_expr = analyzer.Simplify(lhs_expr); - rhs_expr = analyzer.Simplify(rhs_expr); + lhs_expr = analyzer->Simplify(lhs_expr); + rhs_expr = analyzer->Simplify(rhs_expr); } ffi::Optional lower_bound = std::nullopt, upper_bound = std::nullopt; PrimExpr iter; @@ -1430,8 +1430,9 @@ bool IterRangeSanityCheck(const ffi::Map& iter_ranges) { IterMapResult DetectIterMap(const ffi::Array& indices, const ffi::Map& input_iters, const PrimExpr& predicate, - IterMapLevel check_level, arith::Analyzer* analyzer, + IterMapLevel check_level, const arith::Analyzer& analyzer, bool simplify_trivial_iterators) { + arith::AnalyzerObj* analyzer_ptr = analyzer.get(); IterMapResult result; // Overall detection algorithm is divided into two steps: @@ -1459,7 +1460,7 @@ IterMapResult DetectIterMap(const ffi::Array& indices, constraints.begin(), constraints.end(), [](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; }); - IterMapRewriter rewriter(analyzer, constrained_input_iters, check_level, + IterMapRewriter rewriter(analyzer_ptr, constrained_input_iters, check_level, simplify_trivial_iterators, &result->errors); // Step0.0: rewrite constraints in the order from size-small ones to size-big ones for (const IterConstraint& constraint : constraints) { @@ -1521,13 +1522,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](const ffi::Array& indices, const ffi::Map& input_iters, const PrimExpr& input_pred, int check_level, bool simplify_trivial_iterators) { arith::Analyzer ana; - return DetectIterMap(indices, input_iters, input_pred, IterMapLevel(check_level), &ana, + return DetectIterMap(indices, input_iters, input_pred, IterMapLevel(check_level), ana, simplify_trivial_iterators); }); } IterSumExpr NormalizeToIterSum(PrimExpr index, const ffi::Map& input_iters, - arith::Analyzer* analyzer) { + const arith::Analyzer& analyzer) { + arith::AnalyzerObj* analyzer_ptr = analyzer.get(); IterMapResult result; TVM_FFI_ICHECK(IterRangeSanityCheck(input_iters)) << "Invalid iterators. Iterators may not be expressions of each other."; @@ -1536,7 +1538,7 @@ IterSumExpr NormalizeToIterSum(PrimExpr index, const ffi::Map& input std::vector constraints; IterMapLevel check_level = IterMapLevel::NoCheck; bool simplify_trivial_iterators = true; - IterMapRewriter rewriter(analyzer, input_iters, check_level, simplify_trivial_iterators, + IterMapRewriter rewriter(analyzer_ptr, input_iters, check_level, simplify_trivial_iterators, &result->errors); return rewriter.RewriteToNormalizedIterSum(index); @@ -1547,7 +1549,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("arith.NormalizeToIterSum", [](PrimExpr index, const ffi::Map& input_iters) { arith::Analyzer ana; - return NormalizeToIterSum(index, input_iters, &ana); + return NormalizeToIterSum(index, input_iters, ana); }); } @@ -1696,7 +1698,7 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o } /*! \brief Find approximate least common multiplier. */ -PrimExpr ApproxLeastCommonMultiple(const PrimExpr& a, const PrimExpr& b, Analyzer* analyzer) { +PrimExpr ApproxLeastCommonMultiple(const PrimExpr& a, const PrimExpr& b, AnalyzerObj* analyzer) { auto fsplit = [](const PrimExpr& e) -> std::pair { if (const IntImmNode* imm = e.as()) { return {1, imm->value}; @@ -2067,7 +2069,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { */ class IterMapToExprNormalizer : public ExprMutator { public: - explicit IterMapToExprNormalizer(Analyzer* analyzer) : analyzer_(analyzer) {} + explicit IterMapToExprNormalizer(AnalyzerObj* analyzer) : analyzer_(analyzer) {} PrimExpr Convert(const PrimExpr& expr) { return VisitExpr(expr); } @@ -2119,7 +2121,7 @@ class IterMapToExprNormalizer : public ExprMutator { } private: - Analyzer* analyzer_; + AnalyzerObj* analyzer_; }; bool IterMapRewriter::CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs) { @@ -2141,7 +2143,7 @@ bool IterMapRewriter::CanProveDivisible(const PrimExpr& lhs, const PrimExpr& rhs PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr) { arith::Analyzer analyzer; - IterMapToExprNormalizer normalizer(&analyzer); + IterMapToExprNormalizer normalizer(analyzer.get()); return normalizer.Convert(expr); } @@ -2153,7 +2155,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { ffi::Array IterMapSimplify(const ffi::Array& indices, const ffi::Map& input_iters, const PrimExpr& input_pred, IterMapLevel check_level, - arith::Analyzer* ana, bool simplify_trivial_iterators) { + const arith::Analyzer& ana, bool simplify_trivial_iterators) { + arith::AnalyzerObj* ana_ptr = ana.get(); if (!IterRangeSanityCheck(input_iters)) return indices; auto res = DetectIterMap(indices, input_iters, input_pred, check_level, ana, /*simplify_trivial_iterators=*/simplify_trivial_iterators); @@ -2173,7 +2176,7 @@ ffi::Array IterMapSimplify(const ffi::Array& indices, } ffi::Array simplified; simplified.reserve(rewrite.size()); - IterMapToExprNormalizer converter(ana); + IterMapToExprNormalizer converter(ana_ptr); for (const auto& expr : rewrite) simplified.push_back(converter.Convert(expr)); return simplified; } @@ -2185,7 +2188,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](const ffi::Array& indices, const ffi::Map& input_iters, const PrimExpr& input_pred, int check_level, bool simplify_trivial_iterators) { arith::Analyzer ana; - return IterMapSimplify(indices, input_iters, input_pred, IterMapLevel(check_level), &ana, + return IterMapSimplify(indices, input_iters, input_pred, IterMapLevel(check_level), ana, simplify_trivial_iterators); }); } @@ -2205,7 +2208,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { */ class SubspaceDivider { public: - explicit SubspaceDivider(Analyzer* analyzer, const IterMarkSplitCollector& collector, + explicit SubspaceDivider(AnalyzerObj* analyzer, const IterMarkSplitCollector& collector, const std::unordered_set& sub_iters) : analyzer_(analyzer), collector_(collector), sub_iters_(sub_iters) {} @@ -2470,7 +2473,7 @@ class SubspaceDivider { size_t unresolved_count_{0}; // arithmetic analyzer used to call CanProve - Analyzer* analyzer_; + AnalyzerObj* analyzer_; // collector that collects the outgoing split reference of each IterMark const IterMarkSplitCollector collector_; // the set of subspace iters @@ -2486,8 +2489,9 @@ ffi::Array> SubspaceDivide(const ffi::Array& bind const ffi::Map& input_iters, const ffi::Array& sub_iters, const PrimExpr& predicate, IterMapLevel check_level, - arith::Analyzer* analyzer, + const arith::Analyzer& analyzer, bool simplify_trivial_iterators) { + arith::AnalyzerObj* analyzer_ptr = analyzer.get(); if (!IterRangeSanityCheck(input_iters)) return ffi::Array>(); auto res = DetectIterMap(bindings, input_iters, predicate, check_level, analyzer, simplify_trivial_iterators); @@ -2501,7 +2505,7 @@ ffi::Array> SubspaceDivide(const ffi::Array& bind IterMarkSplitCollector collector; collector.Collect(maps); - SubspaceDivider subspace_divider(analyzer, collector, inner_iter_set); + SubspaceDivider subspace_divider(analyzer_ptr, collector, inner_iter_set); std::vector> results; for (const IterSumExpr& expr : maps) { @@ -2525,13 +2529,13 @@ TVM_FFI_STATIC_INIT_BLOCK() { bool simplify_trivial_iterators) { arith::Analyzer ana; return SubspaceDivide(bindings, root_iters, sub_iters, predicate, IterMapLevel(check_level), - &ana, simplify_trivial_iterators); + ana, simplify_trivial_iterators); }); } class InverseAffineIterMapTransformer { public: - explicit InverseAffineIterMapTransformer(Analyzer* analyzer) : analyzer_(analyzer) {} + explicit InverseAffineIterMapTransformer(AnalyzerObj* analyzer) : analyzer_(analyzer) {} ffi::Map operator()(const ffi::Array& iter_map, const ffi::Array& outputs) { @@ -2649,7 +2653,7 @@ class InverseAffineIterMapTransformer { } } - Analyzer* analyzer_; + AnalyzerObj* analyzer_; ffi::Map backprop_; // the accumulator of backpropgation ffi::Map inverse_; // the result of inverse transformation }; @@ -2657,7 +2661,7 @@ class InverseAffineIterMapTransformer { ffi::Map InverseAffineIterMap(const ffi::Array& iter_map, const ffi::Array outputs) { Analyzer analyzer; - return InverseAffineIterMapTransformer(&analyzer)(iter_map, outputs); + return InverseAffineIterMapTransformer(analyzer.get())(iter_map, outputs); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index f01972351b3e..d10a7bad2932 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -99,7 +99,7 @@ struct ModularSetAnalyzer::Entry { class ModularSetAnalyzer::Impl : public ExprFunctor { public: - explicit Impl(Analyzer* parent) : parent_(parent) {} + explicit Impl(AnalyzerObj* parent) : parent_(parent) {} void Update(const Var& var, const ModularSet& info, bool allow_override) { if (!allow_override) { @@ -312,7 +312,7 @@ class ModularSetAnalyzer::Impl : public ExprFunctor var_map_; /*! @@ -403,7 +403,7 @@ std::function ModularSetAnalyzer::EnterConstraint(const PrimExpr& constr return impl_->EnterConstraint(constraint); } -ModularSetAnalyzer::ModularSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {} +ModularSetAnalyzer::ModularSetAnalyzer(AnalyzerObj* parent) : impl_(new Impl(parent)) {} ModularSetAnalyzer::~ModularSetAnalyzer() { delete impl_; } diff --git a/src/arith/presburger_set.cc b/src/arith/presburger_set.cc index c36a19349305..bbe330147cf9 100644 --- a/src/arith/presburger_set.cc +++ b/src/arith/presburger_set.cc @@ -104,7 +104,7 @@ PresburgerSet::PresburgerSet(const PrimExpr& constraint) { }); auto constraints_union = ExtractComponents(constraint); Analyzer analyzer; - PrimExpr simplified_constraint = analyzer.Simplify(constraint, kSimplifyRewriteCanonicalRewrite); + PrimExpr simplified_constraint = analyzer->Simplify(constraint, kSimplifyRewriteCanonicalRewrite); auto space = PresburgerSpace::getRelationSpace(vars.size(), 0, 0, 0); auto node = ffi::make_object(std::move(space), vars); node->SetVars(vars); @@ -120,7 +120,7 @@ PresburgerSet::PresburgerSet(const std::vector& disjuncts, void PresburgerSetNode::UpdateConstraint(const PrimExpr& constraint, const ffi::Array& vars) { Analyzer analyzer; - PrimExpr simplified_constraint = analyzer.Simplify(constraint, kSimplifyRewriteCanonicalRewrite); + PrimExpr simplified_constraint = analyzer->Simplify(constraint, kSimplifyRewriteCanonicalRewrite); Update(simplified_constraint, this); SetVars(vars); } diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index ac2939f53063..2120aaa1a859 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -2459,7 +2459,7 @@ void RewriteSimplifier::SetMaximumRewriteSteps(int64_t maximum) { impl_->SetMaximumRewriteSteps(maximum); } -RewriteSimplifier::RewriteSimplifier(Analyzer* parent) : impl_(new Impl(parent)) {} +RewriteSimplifier::RewriteSimplifier(AnalyzerObj* parent) : impl_(new Impl(parent)) {} RewriteSimplifier::~RewriteSimplifier() { delete impl_; } diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 5f2af7b81705..b42b73336a27 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -88,7 +88,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { public: using IRMutatorWithAnalyzer::VisitExpr_; - explicit Impl(Analyzer* parent) : IRMutatorWithAnalyzer(parent) {} + explicit Impl(AnalyzerObj* parent) : IRMutatorWithAnalyzer(parent) {} PrimExpr VisitExpr(const PrimExpr& e) override; diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index 4b6ac036e8bb..623a906ee75c 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -284,7 +284,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol std::vector rest; Analyzer analyzer_problem; - analyzer_problem.Bind(system_to_solve->ranges); + analyzer_problem->Bind(system_to_solve->ranges); size_t num_vars = system_to_solve->variables.size(); @@ -303,7 +303,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol if (const tirx::EQNode* eq = equation.as()) { // a-b = sum_{i=0}^{n-1} variables[i] * coeff[i] + coeff[n] ffi::Array coeffs = arith::DetectLinearEquation( - analyzer_problem.Simplify(eq->a - eq->b), system_to_solve->variables); + analyzer_problem->Simplify(eq->a - eq->b), system_to_solve->variables); if (!coeffs.empty()) { std::vector row; for (size_t j = 0; j < coeffs.size() - 1; ++j) { @@ -348,7 +348,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol // Simplify right hand sides for (PrimExpr r : Uy) { - r = analyzer_problem.Simplify(r); + r = analyzer_problem->Simplify(r); } // Create the relations of the existence of a solution @@ -362,7 +362,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol // is a divisor of the Ub[j] new_relation = (floormod(Uy[j], std::abs(S[j][j])) == 0); } - new_relation = analyzer_problem.Simplify(new_relation); + new_relation = analyzer_problem->Simplify(new_relation); if (tirx::is_const_int(new_relation, 0)) { // unable to solve the system. return IntConstraintsTransform(system_to_solve, @@ -390,7 +390,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol for (size_t j = 0; j < num_vars; ++j) { if (j >= S.size() || S[j][j] == 0) { // The j-th variable can take any integer value, create a tvm variable for it - PrimExpr to_old = analyzer_problem.Simplify(V_inv_x[j]); + PrimExpr to_old = analyzer_problem->Simplify(V_inv_x[j]); std::string name_hint = "n" + std::to_string(new_vars.size()); if (const VarNode* v_old = to_old.as()) { name_hint += "_" + v_old->name_hint; @@ -404,12 +404,12 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol // S^{-1}_{nxm} Uy_{mxn} if (S[j][j] >= 0) { PrimExpr a = tirx::make_const(Uy[j].dtype(), S[j][j]); - solution_for_V_inv_x.push_back(analyzer_problem.Simplify(floordiv(Uy[j], a))); + solution_for_V_inv_x.push_back(analyzer_problem->Simplify(floordiv(Uy[j], a))); } else { // This is required because some simplifiers // have problems with dividing by negative numbers PrimExpr a = tirx::make_const(Uy[j].dtype(), -S[j][j]); - solution_for_V_inv_x.push_back(analyzer_problem.Simplify(floordiv(-Uy[j], a))); + solution_for_V_inv_x.push_back(analyzer_problem->Simplify(floordiv(-Uy[j], a))); } } } @@ -420,7 +420,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol for (size_t j = 0; j < num_vars; ++j) { e = e + tirx::make_const(e.dtype(), V[i][j]) * solution_for_V_inv_x[j]; } - e = analyzer_problem.Simplify(e); + e = analyzer_problem->Simplify(e); old_to_new_map.Set(system_to_solve->variables[i], e); } @@ -428,7 +428,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol ffi::Map new_ranges = InferRange(new_to_old_map, system_to_solve->variables, system_to_solve->ranges); Analyzer analyzer_solution; - analyzer_solution.Bind(new_ranges); + analyzer_solution->Bind(new_ranges); // We have to transform ranges of the old variables into relations over new variables because // new ranges are not enough usually. @@ -436,9 +436,9 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol if (system_to_solve->ranges.find(old_var) != system_to_solve->ranges.end()) { const Range& old_range = system_to_solve->ranges.at(old_var); PrimExpr express_by_new_vars = old_to_new_map.at(old_var); - PrimExpr lower_cond = analyzer_solution.Simplify(old_range->min <= express_by_new_vars); + PrimExpr lower_cond = analyzer_solution->Simplify(old_range->min <= express_by_new_vars); PrimExpr upper_cond = - analyzer_solution.Simplify(express_by_new_vars < old_range->min + old_range->extent); + analyzer_solution->Simplify(express_by_new_vars < old_range->min + old_range->extent); if (!tirx::is_const_int(lower_cond, 1)) { new_relations.push_back(lower_cond); } diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 64a85d04d70b..aa66dcf5a655 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -92,15 +92,15 @@ class NormalizeComparisons : public ExprMutator { PrimExpr Make(const PrimExpr& a, const PrimExpr& b) { // rewrite LT to LE for ints if (std::is_same::value && (a.dtype().is_int() || a.dtype().is_uint())) { - return LE(analyzer_.Simplify(a - b + 1), make_zero(a.dtype())); + return LE(analyzer_->Simplify(a - b + 1), make_zero(a.dtype())); } - return T(analyzer_.Simplify(a - b), make_zero(a.dtype())); + return T(analyzer_->Simplify(a - b), make_zero(a.dtype())); } arith::Analyzer analyzer_; }; void AddInequality(std::vector* inequality_set, const PrimExpr& new_ineq, - Analyzer* analyzer) { + AnalyzerObj* analyzer) { if (analyzer->CanProve(new_ineq) || std::find_if(inequality_set->begin(), inequality_set->end(), [&](const PrimExpr& e) { return ffi::StructuralEqual()(e, new_ineq); @@ -128,7 +128,8 @@ void AddInequality(std::vector* inequality_set, const PrimExpr& new_in void ClassifyByPolarity(const Var& var, const std::vector& current_ineq_set, std::vector* next_ineq_set, std::vector* rest, std::vector>* coef_pos, - std::vector>* coef_neg, Analyzer* analyzer) { + std::vector>* coef_neg, + AnalyzerObj* analyzer) { // Take formulas from current_ineq_set and classify them according to polarity wrt var // and store to coef_pos and coef_neg respectively. for (const PrimExpr& ineq : current_ineq_set) { @@ -188,7 +189,7 @@ void MoveEquality(std::vector* upper_bounds, std::vector* lo PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_to_solve) { arith::Analyzer analyzer; - analyzer.Bind(system_to_solve->ranges); + analyzer->Bind(system_to_solve->ranges); // The algorithm consists in doing the following things for each variable v // - Take formulas from `current_ineq_set_to_solve` and @@ -213,9 +214,10 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t // Simplify each inequality into the form `expr <= 0` and add to current formulas for (const PrimExpr& ineq : system_to_solve->relations) { - AddInequality(¤t_ineq_set_to_solve, - NormalizeComparisons()(analyzer.Simplify(ineq, kSimplifyRewriteCanonicalRewrite)), - &analyzer); + AddInequality( + ¤t_ineq_set_to_solve, + NormalizeComparisons()(analyzer->Simplify(ineq, kSimplifyRewriteCanonicalRewrite)), + analyzer.get()); } ffi::Map res_bounds; @@ -231,15 +233,15 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t // Add bounds from vranges if (system_to_solve->ranges.count(v)) { const Range& range = system_to_solve->ranges[v]; - PrimExpr range_lbound = analyzer.Simplify(range->min, kSimplifyRewriteCanonicalRewrite); + PrimExpr range_lbound = analyzer->Simplify(range->min, kSimplifyRewriteCanonicalRewrite); PrimExpr range_ubound = - analyzer.Simplify(range->min + range->extent - 1, kSimplifyRewriteCanonicalRewrite); + analyzer->Simplify(range->min + range->extent - 1, kSimplifyRewriteCanonicalRewrite); coef_neg.push_back({-1, range_lbound}); coef_pos.push_back({1, -range_ubound}); } ClassifyByPolarity(v, current_ineq_set_to_solve, &next_ineq_set_to_solve, &rest, &coef_pos, - &coef_neg, &analyzer); + &coef_neg, analyzer.get()); // Combine each positive inequality with each negative one (by adding them together) int64_t gcd_x, gcd_y; @@ -255,8 +257,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t // to help simplify things like (((y + 10) - (-1*(y - 20))) <= 0) => y - 5 <= 0 // with steps = 2 it's (y*2) - 10 <= 0 new_ineq = - NormalizeComparisons()(analyzer.Simplify(new_ineq, kSimplifyRewriteCanonicalRewrite)); - AddInequality(&next_ineq_set_to_solve, new_ineq, &analyzer); + NormalizeComparisons()(analyzer->Simplify(new_ineq, kSimplifyRewriteCanonicalRewrite)); + AddInequality(&next_ineq_set_to_solve, new_ineq, analyzer.get()); } } @@ -280,17 +282,17 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t for (const auto& pos : coef_pos) { PrimExpr bound = make_const(v.dtype(), -coef_lcm / pos.first) * pos.second; - bound = analyzer.Simplify(bound, kSimplifyRewriteCanonicalRewrite); + bound = analyzer->Simplify(bound, kSimplifyRewriteCanonicalRewrite); // Don't add if any of the existing bounds is better if (std::any_of(upper_bounds.begin(), upper_bounds.end(), [&bound, &analyzer](const PrimExpr& o) { - return analyzer.CanProve(o - bound <= 0); + return analyzer->CanProve(o - bound <= 0); })) { continue; } // Erase all worse bounds for (auto iter = upper_bounds.begin(); iter != upper_bounds.end();) { - if (analyzer.CanProve(*iter - bound >= 0)) { + if (analyzer->CanProve(*iter - bound >= 0)) { iter = upper_bounds.erase(iter); } else { ++iter; @@ -301,17 +303,17 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t } for (const auto& neg : coef_neg) { PrimExpr bound = make_const(v.dtype(), -coef_lcm / neg.first) * neg.second; - bound = analyzer.Simplify(bound, kSimplifyRewriteCanonicalRewrite); + bound = analyzer->Simplify(bound, kSimplifyRewriteCanonicalRewrite); // Don't add if any of the existing bounds is better if (std::any_of(lower_bounds.begin(), lower_bounds.end(), [&bound, &analyzer](const PrimExpr& o) { - return analyzer.CanProve(o - bound >= 0); + return analyzer->CanProve(o - bound >= 0); })) { continue; } // Erase all worse bounds for (auto iter = lower_bounds.begin(); iter != lower_bounds.end();) { - if (analyzer.CanProve(*iter - bound <= 0)) { + if (analyzer->CanProve(*iter - bound <= 0)) { iter = lower_bounds.erase(iter); } else { ++iter; @@ -340,7 +342,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t // Everything that is left goes to res.relations ffi::Array other_conditions; for (const PrimExpr& e : current_ineq_set_to_solve) { - PrimExpr e_simp = analyzer.Simplify(e, kSimplifyRewriteCanonicalRewrite); + PrimExpr e_simp = analyzer->Simplify(e, kSimplifyRewriteCanonicalRewrite); if (is_const_int(e_simp, 0)) { // contradiction detected other_conditions = {const_false()}; @@ -385,7 +387,7 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) { // This order is needed to compute new ranges. for (auto it = inequalities->variables.rbegin(); it != inequalities->variables.rend(); ++it) { arith::Analyzer analyzer; - analyzer.Bind(vranges); + analyzer->Bind(vranges); const Var& var = *it; TVM_FFI_ICHECK(solved_bounds.count(var)); @@ -397,7 +399,7 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) { // The MSVC compiler optimization must be disabled for the expression `bnd->equal[0]` which // triggers an internal compiler error. Range best_range(bnd->equal[0], - analyzer.Simplify(bnd->equal[0] + 1, kSimplifyRewriteCanonicalRewrite)); + analyzer->Simplify(bnd->equal[0] + 1, kSimplifyRewriteCanonicalRewrite)); res_ranges.Set(var, best_range); vranges.Set(var, best_range); } else { @@ -408,7 +410,7 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) { auto best_range = bnd.FindBestRange(vranges); if (best_range.defined()) { - if (analyzer.CanProveGreaterEqual(-best_range->extent, 0)) { + if (analyzer->CanProveGreaterEqual(-best_range->extent, 0)) { // range.extent <= 0 implies the input inequality system is unsolvable return IntConstraints(/*variables=*/{}, /*ranges=*/{}, /*relations=*/{tirx::make_zero(DataType::Bool())}); @@ -421,10 +423,10 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) { // Add the original conditions to the resulting conditions arith::Analyzer analyzer; - analyzer.Bind(vranges); + analyzer->Bind(vranges); for (const PrimExpr& old_cond : AsConditions(inequalities->variables, solved_bounds, solved_other_relations)) { - if (!analyzer.CanProve(old_cond)) { + if (!analyzer->CanProve(old_cond)) { // those not represented in vranges (res_ranges) res_relations.push_back(old_cond); } @@ -459,7 +461,7 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ for (std::pair vr : inequalities->ranges) { vranges.Set(vr.first, vr.second); } - analyzer.Bind(vranges); + analyzer->Bind(vranges); // We process variables in the reverse direction to start with the most independent one. // This order is needed to compute new ranges. @@ -490,7 +492,7 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ } else if (is_const_int(best_range->extent, 1)) { // Don't create an itervar, just replace it everywhere with its min res_src_to_dst.Set(var, best_range->min); - } else if (analyzer.CanProveGreaterEqual(-best_range->extent, 0)) { + } else if (analyzer->CanProveGreaterEqual(-best_range->extent, 0)) { // range.extent <= 0 implies the input inequality system is unsolvable return IntConstraintsTransform(inequalities, IntConstraints( @@ -504,7 +506,7 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ // Note that we are substituting old with new, so best_range contains new var, // that is we have to substitute new with old in best_range here res_dst_to_src.Set(new_var, - analyzer.Simplify(var - Substitute(best_range->min, res_dst_to_src))); + analyzer->Simplify(var - Substitute(best_range->min, res_dst_to_src))); // Add the new var to the resulting axis auto range = Range(make_zero(new_var.dtype()), best_range->extent); @@ -512,7 +514,7 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ res_ranges.Set(new_var, range); vranges.Set(new_var, range); - analyzer.Bind(new_var, range); + analyzer->Bind(new_var, range); } } } @@ -520,7 +522,7 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ // Add the original conditions (with variables substituted) to the resulting conditions for (const PrimExpr& old_cond : AsConditions(inequalities->variables, solved_bounds, solved_other_relations)) { - PrimExpr new_cond = analyzer.Simplify(Substitute(old_cond, res_src_to_dst)); + PrimExpr new_cond = analyzer->Simplify(Substitute(old_cond, res_src_to_dst)); if (!is_const_int(new_cond, 1)) { // those not represented in vranges (res_ranges) res_relations.push_back(new_cond); diff --git a/src/relax/analysis/layout_transformation.cc b/src/relax/analysis/layout_transformation.cc index dcee90c9a7ec..e1c81da0b776 100644 --- a/src/relax/analysis/layout_transformation.cc +++ b/src/relax/analysis/layout_transformation.cc @@ -48,7 +48,7 @@ static bool IsBijectiveAffine(const IndexMap& m, const ffi::Array& ranges } arith::Analyzer analyzer; auto iter_map_result = DetectIterMap(m->final_indices, input_iters, /* predicate = */ 1, - /*check_level=*/arith::IterMapLevel::Bijective, &analyzer, + /*check_level=*/arith::IterMapLevel::Bijective, analyzer, /*simplify_trivial_iterators=*/true); return !iter_map_result->indices.empty(); } @@ -176,9 +176,9 @@ static bool AreIdenticalTransforms(const IndexMap& t0, const IndexMap& t1) { ffi::Array t1_initial_indices = t1->initial_indices.Map([](tirx::Var i) -> PrimExpr { return i; }); arith::Analyzer analyzer; - auto t0_output = t0->MapIndices(t1_initial_indices, &analyzer); + auto t0_output = t0->MapIndices(t1_initial_indices, analyzer); for (size_t i = 0; i < t0_output.size(); ++i) { - if (!analyzer.CanProveEqual(t0_output[i], t1->final_indices[i])) return false; + if (!analyzer->CanProveEqual(t0_output[i], t1->final_indices[i])) return false; } return true; } @@ -448,7 +448,7 @@ class BlockAnalyzer : public StmtExprVisitor { SpatialLayout DetectBufferAccessIterMap(ffi::Array indices) { auto result = arith::DetectIterMap( /*indices=*/indices, /*input_iters*/ spatial_dom_, - /*predicate*/ 1, /*check_level*/ arith::IterMapLevel::NoCheck, &arith_analyzer_); + /*predicate*/ 1, /*check_level*/ arith::IterMapLevel::NoCheck, arith_analyzer_); if (result->indices.empty()) { DLOG(INFO) << "[LayoutInference] Failed to analyze indices " << indices << ", error: " << result->errors; diff --git a/src/relax/analysis/shape_analysis.cc b/src/relax/analysis/shape_analysis.cc index e2f624937773..df4bd8376f01 100644 --- a/src/relax/analysis/shape_analysis.cc +++ b/src/relax/analysis/shape_analysis.cc @@ -30,7 +30,7 @@ namespace tvm { namespace relax { bool CanProveShapeEqual(const ffi::Array& lhs, const ffi::Array& rhs, - arith::Analyzer* ana) { + const arith::Analyzer& ana) { if (lhs.same_as(rhs)) return true; if (lhs.size() != rhs.size()) return false; for (size_t i = 0; i < lhs.size(); ++i) { @@ -39,7 +39,7 @@ bool CanProveShapeEqual(const ffi::Array& lhs, const ffi::Array(); auto* rhs_shape = rhs.as(); diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index 66062c1870c3..303148c64938 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -121,7 +121,7 @@ class WellDefinedEraser : public StructInfoMutator, public: WellDefinedEraser(std::function(const tirx::Var& var)> f_shape_var_map, std::function(const Var& var)> f_var_map, - arith::Analyzer* ana) + arith::AnalyzerObj* ana) : f_shape_var_map_(f_shape_var_map), f_var_map_(f_var_map), ana_(ana) {} StructInfo VisitStructInfo_(const PrimStructInfoNode* op) final { @@ -254,23 +254,32 @@ class WellDefinedEraser : public StructInfoMutator, bool has_undefined_ = false; std::function(const tirx::Var& var)> f_shape_var_map_; std::function(const Var& var)> f_var_map_; - arith::Analyzer* ana_; + arith::AnalyzerObj* ana_; }; StructInfo EraseToWellDefined( const StructInfo& info, std::function(const tirx::Var& var)> f_shape_var_map, - std::function(const Var& var)> f_var_map, arith::Analyzer* ana) { - if (ana == nullptr) { - arith::Analyzer inst; - return WellDefinedEraser(f_shape_var_map, f_var_map, &inst).VisitStructInfo(info); - } else { - return WellDefinedEraser(f_shape_var_map, f_var_map, ana).VisitStructInfo(info); - } + std::function(const Var& var)> f_var_map) { + arith::Analyzer analyzer; + return EraseToWellDefined(info, f_shape_var_map, f_var_map, analyzer); +} + +StructInfo EraseToWellDefined( + const StructInfo& info, + std::function(const tirx::Var& var)> f_shape_var_map, + std::function(const Var& var)> f_var_map, const arith::Analyzer& ana) { + return WellDefinedEraser(f_shape_var_map, f_var_map, ana.get()).VisitStructInfo(info); } StructInfo EraseToWellDefined(const StructInfo& info, ffi::Map shape_var_map, - ffi::Map var_map, arith::Analyzer* ana) { + ffi::Map var_map) { + arith::Analyzer analyzer; + return EraseToWellDefined(info, shape_var_map, var_map, analyzer); +} + +StructInfo EraseToWellDefined(const StructInfo& info, ffi::Map shape_var_map, + ffi::Map var_map, const arith::Analyzer& ana) { std::function(const tirx::Var& var)> f_shape_var_map = nullptr; std::function(const Var& var)> f_var_map = nullptr; @@ -307,7 +316,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { class StructInfoBaseChecker : public StructInfoFunctor { public: - explicit StructInfoBaseChecker(arith::Analyzer* ana) : analyzer_(ana) {} + explicit StructInfoBaseChecker(arith::AnalyzerObj* ana) : analyzer_(ana) {} BaseCheckResult VisitStructInfo(const StructInfo& lhs, const StructInfo& other) override { // quick path @@ -485,7 +494,7 @@ class StructInfoBaseChecker protected: // analyzer - arith::Analyzer* analyzer_; + arith::AnalyzerObj* analyzer_; // struct equal checker ffi::StructuralEqual struct_equal_; @@ -596,14 +605,14 @@ class StructInfoBaseChecker } }; +BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived) { + arith::Analyzer analyzer; + return StructInfoBaseCheck(base, derived, analyzer); +} + BaseCheckResult StructInfoBaseCheck(const StructInfo& base, const StructInfo& derived, - arith::Analyzer* ana) { - if (ana == nullptr) { - arith::Analyzer inst; - return StructInfoBaseChecker(&inst)(base, derived); - } else { - return StructInfoBaseChecker(ana)(base, derived); - } + const arith::Analyzer& ana) { + return StructInfoBaseChecker(ana.get())(base, derived); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -614,7 +623,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } -bool IsBaseOf(const StructInfo& base, const StructInfo& derived, arith::Analyzer* ana) { +bool IsBaseOf(const StructInfo& base, const StructInfo& derived) { + arith::Analyzer analyzer; + return IsBaseOf(base, derived, analyzer); +} + +bool IsBaseOf(const StructInfo& base, const StructInfo& derived, const arith::Analyzer& ana) { return StructInfoBaseCheck(base, derived, ana) == BaseCheckResult::kPass; } @@ -833,7 +847,7 @@ PrimExpr StructInfoBaseCheckPrecondition(const StructInfo& base, const StructInf // from the expressions in arg(rhs) to var in param. class CallRetStructInfoDeriver : public StructInfoBaseChecker { public: - explicit CallRetStructInfoDeriver(arith::Analyzer* ana) : StructInfoBaseChecker(ana) {} + explicit CallRetStructInfoDeriver(arith::AnalyzerObj* ana) : StructInfoBaseChecker(ana) {} // No short cut, so we can recursively populate all pairs. BaseCheckResult VisitStructInfo(const StructInfo& lhs, const StructInfo& other) final { @@ -930,7 +944,9 @@ class CallRetStructInfoDeriver : public StructInfoBaseChecker { } else { // Best effort prove. Expr mapped_value = (*it).second; - if (CanProveShapeEqual(mapped_value, rhs, analyzer_)) return BaseCheckResult::kPass; + if (CanProveShapeEqual(mapped_value, rhs, ffi::GetRef(analyzer_))) { + return BaseCheckResult::kPass; + } return BaseCheckResult::kFailL2; } } @@ -962,13 +978,14 @@ class CallRetStructInfoDeriver : public StructInfoBaseChecker { }; StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call, - const BlockBuilder& ctx, arith::Analyzer* ana) { - if (ana == nullptr) { - arith::Analyzer inst; - return CallRetStructInfoDeriver(&inst).Derive(finfo, call, ctx); - } else { - return CallRetStructInfoDeriver(ana).Derive(finfo, call, ctx); - } + const BlockBuilder& ctx) { + arith::Analyzer analyzer; + return DeriveCallRetStructInfo(finfo, call, ctx, analyzer); +} + +StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Call& call, + const BlockBuilder& ctx, const arith::Analyzer& ana) { + return CallRetStructInfoDeriver(ana.get()).Derive(finfo, call, ctx); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -985,7 +1002,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { class StructInfoLCAFinder : public StructInfoFunctor { public: - explicit StructInfoLCAFinder(arith::Analyzer* ana) : analyzer_(ana) {} + explicit StructInfoLCAFinder(arith::AnalyzerObj* ana) : analyzer_(ana) {} StructInfo VisitStructInfo(const StructInfo& lhs, const StructInfo& other) final { // quick path @@ -1028,7 +1045,8 @@ class StructInfoLCAFinder int ndim = lhs->ndim == rhs->ndim ? lhs->ndim : kUnknownNDim; if (lhs->ndim != rhs->ndim || !lhs->values.defined() || !rhs->values.defined() || - !CanProveShapeEqual(lhs->values.value(), rhs->values.value(), analyzer_)) { + !CanProveShapeEqual(lhs->values.value(), rhs->values.value(), + ffi::GetRef(analyzer_))) { // prefers return same when possible if (!lhs->values.defined() && lhs->ndim == ndim) { return ffi::GetRef(lhs); @@ -1055,7 +1073,8 @@ class StructInfoLCAFinder // if ndim mismatch or one side of shape is missing // then we cannot keep in symbolic shape if (lhs->ndim != rhs->ndim || !lhs->shape.defined() || !rhs->shape.defined() || - !CanProveShapeEqual(lhs->shape.value(), rhs->shape.value(), analyzer_)) { + !CanProveShapeEqual(lhs->shape.value(), rhs->shape.value(), + ffi::GetRef(analyzer_))) { // reuse lhs when possible if (!lhs->shape.defined() && lhs->dtype == dtype && lhs->ndim == ndim && (!lhs->vdevice.defined() || vdev.defined())) { @@ -1154,7 +1173,7 @@ class StructInfoLCAFinder private: // analyzer - arith::Analyzer* analyzer_; + arith::AnalyzerObj* analyzer_; // struct equal checker ffi::StructuralEqual struct_equal_; @@ -1168,13 +1187,13 @@ class StructInfoLCAFinder } }; -StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, arith::Analyzer* ana) { - if (ana == nullptr) { - arith::Analyzer inst; - return StructInfoLCAFinder(&inst)(lhs, rhs); - } else { - return StructInfoLCAFinder(ana)(lhs, rhs); - } +StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs) { + arith::Analyzer analyzer; + return StructInfoLCA(lhs, rhs, analyzer); +} + +StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, const arith::Analyzer& ana) { + return StructInfoLCAFinder(ana.get())(lhs, rhs); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc index 26041475c64d..6fb6e8549bbb 100644 --- a/src/relax/analysis/tir_op_pattern_kind.cc +++ b/src/relax/analysis/tir_op_pattern_kind.cc @@ -370,7 +370,7 @@ bool HasReshapePattern(const PrimFunc& func) { : is_reshape_(false), src_buffer_(src_buffer), dst_buffer_(dst_buffer) {} void VisitStmt_(const ForNode* loop) final { - ana_.Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + ana_->Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); // To detect the reshape pattern, we require each For to have // either another For or a BlockRealize as body. if (!(loop->body->IsInstance() || loop->body->IsInstance())) { @@ -408,7 +408,7 @@ bool HasReshapePattern(const PrimFunc& func) { ffi::Map var_range; for (const IterVar& v : block->iter_vars) { - ana_.Bind(v->var, Range::FromMinExtent(v->dom->min, v->dom->extent)); + ana_->Bind(v->var, Range::FromMinExtent(v->dom->min, v->dom->extent)); var_range.Set(v->var, Range::FromMinExtent(v->dom->min, v->dom->extent)); } @@ -441,13 +441,13 @@ bool HasReshapePattern(const PrimFunc& func) { for (int i = 0; i < ndim; ++i) { idx = idx * buffer->shape[i] + indices[i]; } - idx = ana_.Simplify(idx); + idx = ana_->Simplify(idx); return arith::IterMapSimplify( /*indices=*/{idx}, /*input_iters=*/var_range, /*input_pred=*/const_true(), /*check_level=*/arith::IterMapLevel::Surjective, - /*analyzer=*/&ana_, + /*analyzer=*/ana_, /*simplify_trivial_iterators=*/true)[0]; }; @@ -458,9 +458,9 @@ bool HasReshapePattern(const PrimFunc& func) { } for (int i = 0; i < static_cast(block->iter_vars.size()); ++i) { if (!(indices[i].same_as(block->iter_vars[i]->var) && - this->ana_.CanProveEqual(block->iter_vars[i]->dom->min, - IntImm(DataType::Int(64), /*value=*/0)) && - this->ana_.CanProveEqual(buffer->shape[i], block->iter_vars[i]->dom->extent))) { + this->ana_->CanProveEqual(block->iter_vars[i]->dom->min, + IntImm(DataType::Int(64), /*value=*/0)) && + this->ana_->CanProveEqual(buffer->shape[i], block->iter_vars[i]->dom->extent))) { return false; } } @@ -497,7 +497,7 @@ bool HasReshapePattern(const PrimFunc& func) { /*input_iters=*/{{fused_var, Range(IntImm(dtype, /*value=*/0), stride)}}, /*input_pred=*/const_true(), /*check_level=*/arith::IterMapLevel::Surjective, - /*analyzer=*/&this->ana_, + /*analyzer=*/this->ana_, /*simplify_trivial_iterators=*/true); TVM_FFI_ICHECK_EQ(simplify_res.size(), 1); @@ -512,7 +512,7 @@ bool HasReshapePattern(const PrimFunc& func) { PrimExpr src_idx = f_calc_flattened_idx(src_buffer_, buffer_load->indices); PrimExpr dst_idx = f_calc_flattened_idx(dst_buffer_, buffer_store->indices); // Check if we can prove the equality of flattened indices. - if (ana_.CanProveEqual(src_idx, dst_idx)) { + if (ana_->CanProveEqual(src_idx, dst_idx)) { this->is_reshape_ = true; return; } diff --git a/src/relax/distributed/axis_group_graph.cc b/src/relax/distributed/axis_group_graph.cc index c805ea6a5c7f..1252c46ee6af 100644 --- a/src/relax/distributed/axis_group_graph.cc +++ b/src/relax/distributed/axis_group_graph.cc @@ -31,7 +31,7 @@ namespace tvm { namespace tirx { Var GetShardingVarFromIndex(PrimExpr index, ffi::Map var_range, - arith::Analyzer* analyzer) { + const arith::Analyzer& analyzer) { if (index.as()) { return Downcast(index); } @@ -128,7 +128,7 @@ void BuildAxisGraphBinary(const Var& output_var, const Call& call, for (int i = 1; i <= std::min(x1_ndim, x2_ndim); ++i) { const PrimExpr& dim0 = x1_shape->values[x1_ndim - i]; const PrimExpr& dim1 = x2_shape->values[x2_ndim - i]; - if (analyzer.CanProveEqual(dim0, dim1)) { + if (analyzer->CanProveEqual(dim0, dim1)) { // join batch dim axis_group_graph->JoinAxis({tensor_list[0].get(), x1_ndim - i}, {tensor_list[2].get(), std::max(x1_ndim, x2_ndim) - i}, @@ -136,11 +136,11 @@ void BuildAxisGraphBinary(const Var& output_var, const Call& call, axis_group_graph->JoinAxis({tensor_list[1].get(), x2_ndim - i}, {tensor_list[2].get(), std::max(x1_ndim, x2_ndim) - i}, distributed::AxisGroupGraph::EdgeType::kDescend); - } else if (analyzer.CanProveEqual(dim0, 1)) { + } else if (analyzer->CanProveEqual(dim0, 1)) { axis_group_graph->JoinAxis({tensor_list[1].get(), x2_ndim - i}, {tensor_list[2].get(), std::max(x1_ndim, x2_ndim) - i}, distributed::AxisGroupGraph::EdgeType::kDescend); - } else if (analyzer.CanProveEqual(dim1, 1)) { + } else if (analyzer->CanProveEqual(dim1, 1)) { axis_group_graph->JoinAxis({tensor_list[0].get(), x1_ndim - i}, {tensor_list[2].get(), std::max(x1_ndim, x2_ndim) - i}, distributed::AxisGroupGraph::EdgeType::kDescend); @@ -242,18 +242,18 @@ void BuildAxisGraphMatmul(const Var& output_var, const Call& call, const PrimExpr& dim0 = x1_shape_prefix[x1_prefix_ndim - i]; const PrimExpr& dim1 = x2_shape_prefix[x2_prefix_ndim - i]; // join batch dim - if (analyzer.CanProveEqual(dim0, dim1)) { + if (analyzer->CanProveEqual(dim0, dim1)) { axis_group_graph->JoinAxis({x1.get(), x1_prefix_ndim - i}, {x3.get(), std::max(x1_prefix_ndim, x2_prefix_ndim) - i}, distributed::AxisGroupGraph::EdgeType::kDescend); axis_group_graph->JoinAxis({x2.get(), x2_prefix_ndim - i}, {x3.get(), std::max(x1_prefix_ndim, x2_prefix_ndim) - i}, distributed::AxisGroupGraph::EdgeType::kDescend); - } else if (analyzer.CanProveEqual(dim0, 1)) { + } else if (analyzer->CanProveEqual(dim0, 1)) { axis_group_graph->JoinAxis({x2.get(), x2_prefix_ndim - i}, {x3.get(), std::max(x1_prefix_ndim, x2_prefix_ndim) - i}, distributed::AxisGroupGraph::EdgeType::kDescend); - } else if (analyzer.CanProveEqual(dim1, 1)) { + } else if (analyzer->CanProveEqual(dim1, 1)) { axis_group_graph->JoinAxis({x1.get(), x1_prefix_ndim - i}, {x3.get(), std::max(x1_prefix_ndim, x2_prefix_ndim) - i}, distributed::AxisGroupGraph::EdgeType::kDescend); @@ -320,10 +320,10 @@ void BuildAxisGraphReshape(const Var& output_var, const Call& call, PrimExpr old_shape_product = 1, new_shape_product = 1; arith::Analyzer analyzer_; while (i > 0 && j > 0) { - if (analyzer_.CanProve(new_shape_product > old_shape_product)) { + if (analyzer_->CanProve(new_shape_product > old_shape_product)) { i--; old_shape_product *= old_shape_values[i]; - } else if (analyzer_.CanProve(new_shape_product < old_shape_product)) { + } else if (analyzer_->CanProve(new_shape_product < old_shape_product)) { j--; new_shape_product *= new_shape_values[j]; } else { diff --git a/src/relax/distributed/transform/lower_global_view_to_local_view.cc b/src/relax/distributed/transform/lower_global_view_to_local_view.cc index 6fba0cd4c641..6984b00d8101 100644 --- a/src/relax/distributed/transform/lower_global_view_to_local_view.cc +++ b/src/relax/distributed/transform/lower_global_view_to_local_view.cc @@ -230,7 +230,7 @@ class DistributedBufferCompactor : StmtExprMutator { for (const auto& pr : dim_shards) { int dim = pr.first; int shard = pr.second; - Var var = GetShardingVarFromIndex(access_index[dim], iter_var_range, &analyzer); + Var var = GetShardingVarFromIndex(access_index[dim], iter_var_range, analyzer); TVM_FFI_ICHECK(!iter_var_shards_.count(var) || iter_var_shards_[var] == shard) << "A loop cannot have different sharding"; iter_var_shards_[var] = shard; @@ -246,7 +246,7 @@ class DistributedBufferCompactor : StmtExprMutator { Range dom = iter_var->dom; TVM_FFI_ICHECK(is_zero(dom->min)); arith::Analyzer analyzer; - TVM_FFI_ICHECK(analyzer.CanProve(floormod(dom->extent, shard) == 0)); + TVM_FFI_ICHECK(analyzer->CanProve(floormod(dom->extent, shard) == 0)); new_iter_vars.push_back( IterVar(Range::FromMinExtent(dom->min, floordiv(dom->extent, shard)), iter_var->var, iter_var->iter_type, iter_var->thread_tag)); @@ -334,7 +334,7 @@ class DistributedBufferCompactor : StmtExprMutator { int shard = loop_var_shards_[op->loop_var]; if (shard > 1) { arith::Analyzer analyzer; - TVM_FFI_ICHECK(analyzer.CanProve(floormod(new_loop->extent, shard) == 0)); + TVM_FFI_ICHECK(analyzer->CanProve(floormod(new_loop->extent, shard) == 0)); new_loop.CopyOnWrite()->extent = floordiv(new_loop->extent, shard); return new_loop; } diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 1061c02eb1f8..cdcedd298485 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -219,11 +219,11 @@ class BlockBuilderImpl : public BlockBuilderNode { // of shape inference. In many cases, knowning that the // shape variable is non-negative allows for simpler // expressions for dynamic shapes. - analyzer_.MarkGlobalNonNegValue(shape_var); + analyzer_->MarkGlobalNonNegValue(shape_var); } else { const PrimExpr& old_shape_expr = (*it).second; TVM_FFI_ICHECK(old_shape_expr.same_as(shape_expr) || - analyzer_.CanProveEqual(old_shape_expr, shape_expr)) + analyzer_->CanProveEqual(old_shape_expr, shape_expr)) << "Inconsistent shape var " << shape_var << " in scope: " << old_shape_expr << " vs " << shape_expr; } @@ -307,7 +307,7 @@ class BlockBuilderImpl : public BlockBuilderNode { } } - arith::Analyzer* GetAnalyzer() final { return &analyzer_; } + arith::Analyzer GetAnalyzer() final { return analyzer_; } protected: /*! @@ -855,7 +855,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor(call->op); TVM_FFI_ICHECK(opt) << "Call->op must contains a function struct info"; FuncStructInfo finfo = opt.value(); - return DeriveCallRetStructInfo(finfo, call, ffi::GetRef(this), &analyzer_); + return DeriveCallRetStructInfo(finfo, call, ffi::GetRef(this), analyzer_); } } diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc index 57f17bdbbcce..7344d05ec7e8 100644 --- a/src/relax/ir/dataflow_block_rewriter.cc +++ b/src/relax/ir/dataflow_block_rewriter.cc @@ -190,7 +190,7 @@ static std::optional TryMatch(const PNode& p, const RNode& r, static std::optional TryValidate( const MatchState& current_match, const std::unordered_map& pattern2node, - const std::vector& validation_constraints, arith::Analyzer* analyzer) { + const std::vector& validation_constraints, arith::AnalyzerObj* analyzer) { MatchState new_match; std::function(const DFPatternNode*)> query_match_state = @@ -244,7 +244,7 @@ static std::optional MatchTree( const std::unordered_map& pattern2node, const std::unordered_map& var2node, DFPatternMatcher* matcher, const std::vector& roots, const std::vector& validation_constraints, - const MatcherUseDefAnalysis& ud_analysis, arith::Analyzer* analyzer) { + const MatcherUseDefAnalysis& ud_analysis, arith::AnalyzerObj* analyzer) { auto get_next_root = [&](size_t root_idx) -> const PNode* { // Look for the next unmatched root node. for (; root_idx < roots.size(); ++root_idx) { @@ -348,7 +348,7 @@ ffi::Optional> MatchGraph(const PatternContext& ctx, arith::Analyzer analyzer; auto match = MatchTree({}, 0, pattern2node, var2node, &matcher, roots, - ctx->validation_constraints, ud_analysis, &analyzer); + ctx->validation_constraints, ud_analysis, analyzer.get()); if (!match) { return std::nullopt; } diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 57578773c675..ad653087a088 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -55,6 +55,7 @@ namespace tvm { namespace relax { using tvm::arith::Analyzer; +using tvm::arith::AnalyzerObj; /*! * \brief Match the attributes of an object. @@ -476,10 +477,10 @@ PrimExpr DFPatternMatcher::SimplifyCondition(PrimExpr condition) { sorted_condition = sorted_condition && constraint; } - return analyzer_.Simplify(sorted_condition); + return analyzer_->Simplify(sorted_condition); } -static bool ShapeEqual(Analyzer* analyzer, const ffi::Array& lhs, +static bool ShapeEqual(AnalyzerObj* analyzer, const ffi::Array& lhs, const ffi::Array& rhs) { if (lhs.size() != rhs.size()) return false; for (size_t i = 0; i < lhs.size(); ++i) @@ -491,7 +492,7 @@ bool DFPatternMatcher::VisitDFPattern_(const ShapePatternNode* op, const Expr& e // no need to jump, as var.shape == value.shape if (const auto* tinfo = GetStructInfoAs(expr)) { if (const ShapeExprNode* shape_expr = tinfo->shape.as()) { - return ShapeEqual(&analyzer_, op->shape, shape_expr->values) && + return ShapeEqual(analyzer_.get(), op->shape, shape_expr->values) && VisitDFPattern(op->pattern, expr); } } @@ -564,7 +565,7 @@ std::tuple SameShapeConstraintNode::AsPrimExpr( bool DFPatternMatcher::VisitDFPattern_(const PrimArrPatternNode* op, const Expr& expr0) { auto expr = UnwrapBindings(expr0, var2val_); if (const ShapeExprNode* shape_expr = expr.as()) - return ShapeEqual(&analyzer_, op->fields, shape_expr->values); + return ShapeEqual(analyzer_.get(), op->fields, shape_expr->values); return false; } diff --git a/src/relax/op/ccl/ccl.cc b/src/relax/op/ccl/ccl.cc index 7f7eb3c8935d..6885dd7a6f02 100644 --- a/src/relax/op/ccl/ccl.cc +++ b/src/relax/op/ccl/ccl.cc @@ -146,7 +146,7 @@ StructInfo InferStructInfoScatter(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); int num_workers = attrs->num_workers; - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); auto input_shape = input_sinfo->GetShape(); TVM_FFI_ICHECK(input_shape.defined()) << "input tensor of scatter_from_worker0 should have defined shape."; diff --git a/src/relax/op/distributed/distributed.cc b/src/relax/op/distributed/distributed.cc index bee2751564d9..2ef8725c8912 100644 --- a/src/relax/op/distributed/distributed.cc +++ b/src/relax/op/distributed/distributed.cc @@ -160,7 +160,7 @@ StructInfo InferStructInfoRtoS(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as(); int num_workers = attrs->num_workers; - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); auto input_shape = input_sinfo->GetShape(); TVM_FFI_ICHECK(input_shape.defined()) << "input tensor of redistribute_replica_to_shard should have defined shape."; @@ -188,7 +188,7 @@ StructInfo InferDistStructInfoRtoS(const Call& call, const BlockBuilder& ctx) { TensorStructInfo tensor_sinfo = input_dtensor_sinfo->tensor_sinfo; const auto* attrs = call->attrs.as(); int num_workers = attrs->num_workers; - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); auto input_shape = tensor_sinfo->GetShape(); TVM_FFI_ICHECK(input_shape.defined()) << "input tensor of redistribute_replica_to_shard should have defined shape."; diff --git a/src/relax/op/distributed/linear_algebra.cc b/src/relax/op/distributed/linear_algebra.cc index aeee041afb40..7f3b01005ee2 100644 --- a/src/relax/op/distributed/linear_algebra.cc +++ b/src/relax/op/distributed/linear_algebra.cc @@ -75,7 +75,7 @@ StructInfo InferDistStructInfoMatmul(const Call& call, const BlockBuilder& ctx) ffi::Optional> output_shape_prefix = InferBinaryBroadcastShape(call, ctx, x1_shape_prefix, x2_shape_prefix); TVM_FFI_ICHECK(output_shape_prefix.defined()) << "Failed to infer output shape of Matmul"; - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); PrimExpr x1_reduction_length = x1_shape->values[x1_sinfo->ndim - 1]; PrimExpr x2_reduction_length = x2_shape->values[x2_ndim - 2]; if (analyzer->CanProve(x1_reduction_length != x2_reduction_length)) { diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc index f19c55b5d2ec..473ee7217ae1 100644 --- a/src/relax/op/nn/attention.cc +++ b/src/relax/op/nn/attention.cc @@ -89,7 +89,7 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { PrimExpr head_dim = q_shape->values[3]; PrimExpr num_keys = k_shape->values[1]; PrimExpr head_dim_value = v_shape->values[3]; - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); auto diag_equal = [&](PrimExpr v1, PrimExpr v2, ffi::String m1, ffi::String m2, ffi::String dim) { if (analyzer->CanProve(v1 != v2)) { ctx->ReportFatal(Diagnostic::Error(call) diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index 8916e430822c..d497d2219741 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -102,7 +102,7 @@ StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) { ffi::Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); ffi::Array weight_OIW_shape = weight2OIW.ForwardShape(weight_shape.value()->values); - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); PrimExpr input_channel_data = data_NCW_shape[1]; PrimExpr input_channel_kernel = weight_OIW_shape[1]; if (analyzer->CanProve(input_channel_data != input_channel_kernel * attrs->groups)) { @@ -274,7 +274,7 @@ StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) { ffi::Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); ffi::Array weight_OIHW_shape = weight2OIHW.ForwardShape(weight_shape.value()->values); - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); PrimExpr input_channel_data = data_NCHW_shape[1]; PrimExpr input_channel_kernel = weight_OIHW_shape[1]; if (analyzer->CanProve(input_channel_data != input_channel_kernel * attrs->groups)) { @@ -490,7 +490,7 @@ StructInfo InferStructInfoConv3d(const Call& call, const BlockBuilder& ctx) { ffi::Array data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values); ffi::Array weight_OIDHW_shape = weight2OIDHW.ForwardShape(weight_shape.value()->values); - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); PrimExpr input_channel_data = data_NCDHW_shape[1]; PrimExpr input_channel_kernel = weight_OIDHW_shape[1]; if (analyzer->CanProve(input_channel_data != input_channel_kernel * attrs->groups)) { @@ -684,7 +684,7 @@ StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder& ffi::Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); ffi::Array weight_IOW_shape = weight2IOW.ForwardShape(weight_shape.value()->values); - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); PrimExpr input_channel_data = data_NCW_shape[1]; PrimExpr input_channel_kernel = weight_IOW_shape[0]; if (analyzer->CanProve(input_channel_data != input_channel_kernel)) { @@ -879,7 +879,7 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder& ffi::Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); ffi::Array weight_IOHW_shape = weight2IOHW.ForwardShape(weight_shape.value()->values); - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); PrimExpr input_channel_data = data_NCHW_shape[1]; PrimExpr input_channel_kernel = weight_IOHW_shape[0]; if (analyzer->CanProve(input_channel_data != input_channel_kernel)) { @@ -1115,7 +1115,7 @@ StructInfo InferStructInfoConv3dTranspose(const Call& call, const BlockBuilder& ffi::Array data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values); ffi::Array weight_IODHW_shape = weight2IODHW.ForwardShape(weight_shape.value()->values); - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); PrimExpr input_channel_data = data_NCDHW_shape[1]; PrimExpr input_channel_kernel = weight_IODHW_shape[0]; if (analyzer->CanProve(input_channel_data != input_channel_kernel)) { diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index b6e2051a68f7..e5dbb1dc9cce 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -423,7 +423,7 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx, } } - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); for (int i = 1; i < static_cast(axis_lengths.size()); ++i) { for (int d = 0; d < n_axis; ++d) { if (analyzer->CanProve(axis_lengths[0][d] != axis_lengths[i][d])) { @@ -634,7 +634,7 @@ StructInfo InferStructInfoGroupNorm(const Call& call, const BlockBuilder& ctx) { ctx->ReportFatal(Diagnostic::Error(call) << op << " expects that data must be float, but got " << data_sinfo->dtype); } - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); const auto* data_shape = data_sinfo->shape.as(); if (data_shape != nullptr && channel_axis != -1 && analyzer->CanProve(floormod(data_shape->values[channel_axis], attrs->num_groups) != 0)) { @@ -745,7 +745,7 @@ StructInfo InferStructInfoInstanceNorm(const Call& call, const BlockBuilder& ctx } } const auto* data_shape = data_sinfo->shape.as(); - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); for (int i = 1; i < static_cast(op->arguments.size()); ++i) { if (input_sinfo[i]->dtype != data_sinfo->dtype) { ctx->ReportFatal(Diagnostic::Error(call) @@ -929,7 +929,7 @@ StructInfo InferStructInfoCrossEntropy(const Call& call, const BlockBuilder& ctx } if (pred_shape_value.defined() && label_shape_value.defined()) { - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); for (size_t i = 0; i < pred_shape_value.value().size(); ++i) { if (analyzer->CanProve(pred_shape_value.value()[i] != label_shape_value.value()[i])) { ctx->ReportFatal(Diagnostic::Error(call) @@ -1067,7 +1067,7 @@ StructInfo InferStructInfoNLLLoss(const Call& call, const BlockBuilder& ctx) { << wgt_sinfo->ndim); } - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); ffi::Optional N; ffi::Optional C; ffi::Array output_shape; // N, d1, d2, ..., dk diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index 2be119b788ec..dcf44eebba80 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -103,7 +103,7 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { PrimExpr padding_w = IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[1]); - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); std::vector out_NCW_shape; out_NCW_shape.resize(3); out_NCW_shape[0] = data_NCW_shape[0]; @@ -232,7 +232,7 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { PrimExpr padding_w = IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[3]); - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); std::vector out_NCHW_shape; out_NCHW_shape.resize(4); out_NCHW_shape[0] = data_NCHW_shape[0]; @@ -394,7 +394,7 @@ StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) { PrimExpr padding_w = IntImm(DataType::Int(32), attrs->padding[2]) + IntImm(DataType::Int(32), attrs->padding[5]); - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); std::vector out_NCDHW_shape; out_NCDHW_shape.resize(5); out_NCDHW_shape[0] = data_NCDHW_shape[0]; diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 8a28ab361af2..0b4c2c4b4148 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -51,7 +51,7 @@ bool EqualCheck(const PrimExpr& lhs, const PrimExpr& rhs) { return pdiff[0] == 0; } tvm::arith::Analyzer ana; - diff = ana.Simplify(diff); + diff = ana->Simplify(diff); if (const int64_t* pdiff = tirx::as_const_int(diff)) { return pdiff[0] == 0; } diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc index a019b87f3a2b..f6dd34ede6b0 100644 --- a/src/relax/op/op_common.cc +++ b/src/relax/op/op_common.cc @@ -109,7 +109,7 @@ ffi::Array GetTensorStructInfoFromTuple(const Call& call, cons return tensor_sinfo; } -BinaryBroadcastShapeInferResult InferBinaryBroadcastShape(arith::Analyzer* analyzer, +BinaryBroadcastShapeInferResult InferBinaryBroadcastShape(arith::AnalyzerObj* analyzer, const ffi::Array& x1_shape, const ffi::Array& x2_shape) { BinaryBroadcastShapeInferResult result; @@ -159,7 +159,7 @@ BinaryBroadcastShapeInferResult InferBinaryBroadcastShape(arith::Analyzer* analy ffi::Optional> InferBinaryBroadcastShape( const Call& call, const BlockBuilder& ctx, const ffi::Array& x1_shape, const ffi::Array& x2_shape) { - auto infer_result = InferBinaryBroadcastShape(ctx->GetAnalyzer(), x1_shape, x2_shape); + auto infer_result = InferBinaryBroadcastShape(ctx->GetAnalyzer().get(), x1_shape, x2_shape); if (infer_result.status == BinaryBroadcastShapeInferResult::Status::kConflict) { TVM_FFI_ICHECK(infer_result.message.has_value()); ctx->ReportFatal(Diagnostic::Error(call) @@ -223,7 +223,7 @@ bool CanProveLayoutTransform(const SLayout& input_layout, const SLayout& desired arith::Analyzer analyzer; for (size_t i = 0; i < shape.size(); ++i) { if (tirx::is_const_int(shape[i])) { - if (!analyzer.CanProveEqual(shape[i], back_shape[i])) { + if (!analyzer->CanProveEqual(shape[i], back_shape[i])) { can_prove = false; break; } diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 6f7de974cbe6..32e8da5ce997 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -413,7 +413,7 @@ struct BinaryBroadcastShapeInferResult { * \param x2_shape The shape of the second operand. * \return Inference status and broadcasted shape, or a conflict message. */ -BinaryBroadcastShapeInferResult InferBinaryBroadcastShape(arith::Analyzer* analyzer, +BinaryBroadcastShapeInferResult InferBinaryBroadcastShape(arith::AnalyzerObj* analyzer, const ffi::Array& x1_shape, const ffi::Array& x2_shape); diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index 885f7c87257e..fdc096b09f28 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -375,7 +375,7 @@ StructInfo InferStructInfoArange(const Call& call, const BlockBuilder& ctx) { tvm::ceil(tvm::cast(tvm::DataType::Float(32), end - start) / step)); } arith::Analyzer analyzer; - num_elem = analyzer.Simplify(num_elem); + num_elem = analyzer->Simplify(num_elem); return TensorStructInfo(ShapeExpr({num_elem}), dtype); } @@ -421,12 +421,12 @@ StructInfo InferStructInfoHammingWindow(const Call& call, const BlockBuilder& ct PrimExpr window_size = get_prim_value(call->args[0], "window_size"); arith::Analyzer analyzer; - if (analyzer.CanProveLess(window_size, 1)) { + if (analyzer->CanProveLess(window_size, 1)) { ctx->ReportFatal(Diagnostic::Error(call) << "Hamming_window expects the window_size must be greater than zero but got " << window_size); } - window_size = analyzer.Simplify(window_size); + window_size = analyzer->Simplify(window_size); return TensorStructInfo(ShapeExpr({window_size}), dtype); } diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index 79bedfdc485c..4b6f9551ac39 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -420,7 +420,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx PrimExpr output_dim = topi::GetLength(begin, end, strides_tuple[i], input_dim, attrs->assume_inbound); - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); std::optional> context; if (attrs->assume_inbound) { context.emplace(analyzer, 0 <= begin && begin <= input_dim && 0 <= end && end <= input_dim); diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc index 6936fa04348b..fbf09905468e 100644 --- a/src/relax/op/tensor/linear_algebra.cc +++ b/src/relax/op/tensor/linear_algebra.cc @@ -134,7 +134,7 @@ StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) { return TensorStructInfo(out_dtype, output_ndim); } - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); PrimExpr x1_reduction_length = x1_shape->values[x1_sinfo->ndim - 1]; PrimExpr x2_reduction_length = x2_shape->values[x2_ndim - 2]; if (analyzer->CanProve(x1_reduction_length != x2_reduction_length)) { diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 763e37ae6815..b42be1dedf67 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -108,7 +108,7 @@ StructInfo InferStructInfoBroadcastTo(const Call& call, const BlockBuilder& ctx) return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype, data_sinfo->vdevice); } - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); ffi::Array old_shape_value = shape_sinfo->values.value(); ffi::Array tgt_shape_value = tgt_shape_sinfo->values.value(); int old_ndim = old_shape_value.size(); @@ -160,7 +160,7 @@ ffi::Optional> CheckConcatOutputShape( const Call& call, const BlockBuilder& ctx, const std::vector>& shape_values, int axis) { bool shape_unknown = false; - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); PrimExpr concat_sum = [&]() { // For the specified axis, we compute the sum of shape value over each tensor. @@ -601,7 +601,7 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) << " index tensors, but data has only " << data_sinfo->ndim << " dimensions"); } - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); bool all_index_have_shape_value = true; std::vector> index_shapes; int max_index_ndim = 0; @@ -765,7 +765,7 @@ StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& } arith::Analyzer analyzer; - ffi::Array output_shape = index_map->MapShape(shape_sinfo->values.value(), &analyzer); + ffi::Array output_shape = index_map->MapShape(shape_sinfo->values.value(), analyzer); return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype, data_sinfo->vdevice); } @@ -991,7 +991,7 @@ Expr ConvertNewShapeToExpr(const Expr& data, if (dim_to_infer != -1) { arith::Analyzer analyzer; PrimExpr old_shape_prod = ComputeShapeProduct(shape_sinfo->values.value()); - array_ref.Set(dim_to_infer, analyzer.Simplify(floordiv(old_shape_prod, new_shape_prod))); + array_ref.Set(dim_to_infer, analyzer->Simplify(floordiv(old_shape_prod, new_shape_prod))); } return ShapeExpr(array_ref); } @@ -1403,7 +1403,7 @@ TVM_REGISTER_OP("relax.squeeze") void CheckCollapseShape(const Call& call, const BlockBuilder& ctx, const ffi::Array& data_shape, const ffi::Array& target_shape) { - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); int data_ndim = data_shape.size(); int target_ndim = target_shape.size(); @@ -1458,7 +1458,7 @@ ffi::Optional> CheckStackOutputShape( const Call& call, const BlockBuilder& ctx, const std::vector>& shape_values, int axis) { bool shape_unknown = false; - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); // Stack requires all input tensors to have identical shapes for (int d = 0; d < static_cast(shape_values[0].size()); ++d) { @@ -1771,7 +1771,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } StructInfo InferStructInfoRepeat(const Call& call, const BlockBuilder& ctx) { - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); const auto* data_shape = data_sinfo->shape.as(); @@ -1896,7 +1896,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } StructInfo InferStructInfoTile(const Call& call, const BlockBuilder& ctx) { - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); const auto* data_shape = data_sinfo->shape.as(); @@ -2568,7 +2568,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder& ctx) { - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); const auto* data_sinfo = GetStructInfoAs(call->args[0]); const auto* indices_sinfo = GetStructInfoAs(call->args[1]); const auto* updates_sinfo = GetStructInfoAs(call->args[2]); @@ -2712,7 +2712,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { StructInfo InferStructInfoScatterND(const Call& call, const BlockBuilder& ctx) { // `call->args` contains: [data, indices, updates] - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); TVM_FFI_ICHECK_EQ(call->args.size(), 3); const auto* data_sinfo = GetStructInfoAs(call->args[0]); const auto* indices_sinfo = GetStructInfoAs(call->args[1]); @@ -2888,7 +2888,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } StructInfo InferStructInfoSliceScatter(const Call& call, const BlockBuilder& ctx) { - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); const auto* data_sinfo = GetStructInfoAs(call->args[0]); const auto* src_sinfo = GetStructInfoAs(call->args[1]); auto* attrs = call->attrs.as(); diff --git a/src/relax/op/tensor/sampling.cc b/src/relax/op/tensor/sampling.cc index febe4d521d3d..63fe2b77b765 100644 --- a/src/relax/op/tensor/sampling.cc +++ b/src/relax/op/tensor/sampling.cc @@ -115,7 +115,7 @@ StructInfo InferStructInfoMultinomialFromUniform(const Call& call, const BlockBu PrimExpr batch = prob_shape->values[0]; PrimExpr n = uniform_sample_shape->values[0]; arith::Analyzer ana; - if (!ana.CanProveEqual(n, sample_indices_shape->values[0])) { + if (!ana->CanProveEqual(n, sample_indices_shape->values[0])) { ctx->ReportFatal(Diagnostic::Error(call) << "Multinomial_from_uniform op requires the input uniform_sample and " "sample_indices to have the same batch size. " diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc index 523c694ff5e8..b854b33288c9 100644 --- a/src/relax/op/tensor/ternary.cc +++ b/src/relax/op/tensor/ternary.cc @@ -85,7 +85,7 @@ StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx) { auto* s1 = t1->shape.as(); auto* s2 = t2->shape.as(); auto* s3 = t3->shape.as(); - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); if (s1 && s2 && s3) { ffi::Array output_shape; for (int i = 0; i < ndim; ++i) { diff --git a/src/relax/op/vision/nms.cc b/src/relax/op/vision/nms.cc index dbfe0d63aff5..b7f4c4d95ba6 100644 --- a/src/relax/op/vision/nms.cc +++ b/src/relax/op/vision/nms.cc @@ -269,7 +269,7 @@ StructInfo InferStructInfoNMS(const Call& call, const BlockBuilder& ctx) { const auto* valid_count_shape = valid_count_sinfo->shape.as(); const auto* indices_shape = indices_sinfo->shape.as(); if (data_shape != nullptr) { - arith::Analyzer* analyzer = ctx->GetAnalyzer(); + arith::Analyzer analyzer = ctx->GetAnalyzer(); PrimExpr batch = data_shape->values[0]; PrimExpr num_anchors = data_shape->values[1]; if (valid_count_shape != nullptr && diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc index 012c8ce5b71a..e97e423e9b78 100644 --- a/src/relax/transform/adjust_matmul_order.cc +++ b/src/relax/transform/adjust_matmul_order.cc @@ -55,7 +55,7 @@ PrimExpr ProductDims(const ffi::Array& dims) { } ffi::Optional> InferBatchedMatmulBroadcastPrefix( - arith::Analyzer* analyzer, const ffi::Array& x1, const ffi::Array& x2) { + arith::AnalyzerObj* analyzer, const ffi::Array& x1, const ffi::Array& x2) { auto infer_result = InferBinaryBroadcastShape(analyzer, x1, x2); if (infer_result.status == BinaryBroadcastShapeInferResult::Status::kSuccess) { return infer_result.shape; @@ -244,15 +244,15 @@ std::tuple)>> auto prefix_b = GetBatchPrefix(shape_b); auto prefix_c = GetBatchPrefix(shape_c); - auto opt_prefix_ab = InferBatchedMatmulBroadcastPrefix(&analyzer, prefix_a, prefix_b); + auto opt_prefix_ab = InferBatchedMatmulBroadcastPrefix(analyzer.get(), prefix_a, prefix_b); if (!opt_prefix_ab) return expr; - auto opt_prefix_bc = InferBatchedMatmulBroadcastPrefix(&analyzer, prefix_b, prefix_c); + auto opt_prefix_bc = InferBatchedMatmulBroadcastPrefix(analyzer.get(), prefix_b, prefix_c); if (!opt_prefix_bc) return expr; auto opt_prefix_outer_lhs = - InferBatchedMatmulBroadcastPrefix(&analyzer, opt_prefix_ab.value(), prefix_c); + InferBatchedMatmulBroadcastPrefix(analyzer.get(), opt_prefix_ab.value(), prefix_c); if (!opt_prefix_outer_lhs) return expr; auto opt_prefix_outer_rhs = - InferBatchedMatmulBroadcastPrefix(&analyzer, prefix_a, opt_prefix_bc.value()); + InferBatchedMatmulBroadcastPrefix(analyzer.get(), prefix_a, opt_prefix_bc.value()); if (!opt_prefix_outer_rhs) return expr; PrimExpr batch_ab = ProductDims(opt_prefix_ab.value()); @@ -275,17 +275,18 @@ std::tuple)>> PrimExpr ops_with_rhs_first = batch_bc * size_R * size_M * size_B + batch_outer_rhs * size_N * size_R * size_B; - analyzer.rewrite_simplify.SetEnabledExtensions(static_cast( - analyzer.rewrite_simplify.GetEnabledExtensions() | - arith::RewriteSimplifier::Extension::kComparisonOfProductAndSum)); - With func_attr_constraint(&analyzer, symbolic_var_constraints); + analyzer->rewrite_simplify.SetEnabledExtensions( + static_cast( + analyzer->rewrite_simplify.GetEnabledExtensions() | + arith::RewriteSimplifier::Extension::kComparisonOfProductAndSum)); + With func_attr_constraint(analyzer, symbolic_var_constraints); With analyzer_constraint( - &analyzer, batch_ab > 0 && batch_bc > 0 && batch_outer_lhs > 0 && batch_outer_rhs > 0 && - size_N > 0 && size_R > 0 && size_M > 0 && size_B > 0); + analyzer, batch_ab > 0 && batch_bc > 0 && batch_outer_lhs > 0 && batch_outer_rhs > 0 && + size_N > 0 && size_R > 0 && size_M > 0 && size_B > 0); - if (analyzer.CanProve(ops_with_lhs_first < ops_with_rhs_first)) { + if (analyzer->CanProve(ops_with_lhs_first < ops_with_rhs_first)) { return matmul(matmul(expr_a, expr_b, DataType::Void()), expr_c, DataType::Void()); - } else if (analyzer.CanProve(ops_with_rhs_first < ops_with_lhs_first)) { + } else if (analyzer->CanProve(ops_with_rhs_first < ops_with_lhs_first)) { return matmul(expr_a, matmul(expr_b, expr_c, DataType::Void()), DataType::Void()); } diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc index 09492a5869a2..16e492a80d0a 100644 --- a/src/relax/transform/alter_op_impl.cc +++ b/src/relax/transform/alter_op_impl.cc @@ -68,9 +68,9 @@ bool IsTransformBijective(const Expr& expr, const IndexMap& transform) { ffi::Array input_shape = GetShapeFromTensor(expr); ffi::Array initial_ranges = ConstructRangeFromShape(input_shape); arith::Analyzer analyzer; - auto [inverse, padding_predicate] = transform.NonSurjectiveInverse(initial_ranges, &analyzer); + auto [inverse, padding_predicate] = transform.NonSurjectiveInverse(initial_ranges, analyzer); (void)inverse; // to avoid unused variable warning; - if (!analyzer.CanProve(!padding_predicate)) return false; + if (!analyzer->CanProve(!padding_predicate)) return false; return true; } @@ -256,7 +256,7 @@ class AlterOpImplMutator : public ExprMutator { ffi::Array initial_ranges = ConstructRangeFromShape(old_shape); arith::Analyzer analyzer; auto [inverse_index_map, padding_predicate] = - index_map.NonSurjectiveInverse(initial_ranges, &analyzer); + index_map.NonSurjectiveInverse(initial_ranges, analyzer); if (tirx::is_zero(padding_predicate)) { return TransformLayout(expr, inverse_index_map, axis_separator, input_axis_separator); @@ -352,7 +352,7 @@ class AlterOpImplMutator : public ExprMutator { if (transform.get() == nullptr) return tensor_sinfo; auto shape = GetShapeFromTensorStructInfo(tensor_sinfo); arith::Analyzer analyzer; - auto new_shape = transform->MapShape(shape, &analyzer); + auto new_shape = transform->MapShape(shape, analyzer); if (tensor_sinfo->vdevice.defined()) { return TensorStructInfo(ShapeExpr(new_shape), tensor_sinfo->dtype, tensor_sinfo->vdevice.value()); diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc index ff5ad73380f0..c7b4cc5e9ba0 100644 --- a/src/relax/transform/bind_params.cc +++ b/src/relax/transform/bind_params.cc @@ -33,7 +33,8 @@ namespace tvm { namespace relax { void MatchSymbolicVar(const Expr& arg, const Expr& constant, - ffi::Map* symbolic_var_map, arith::Analyzer* analyzer_) { + ffi::Map* symbolic_var_map, + arith::AnalyzerObj* analyzer_) { auto opt_arg_sinfo = MatchStructInfo(arg); TVM_FFI_ICHECK(opt_arg_sinfo) << "The struct info of the bound parameter is expected to be TensorStructInfo, but got: " @@ -145,7 +146,7 @@ std::tuple, ffi::Map> NormalizeBindings } arith::Analyzer analyzer; - ffi::Map symbolic_var_map = InferSymbolicVarMap(relax_var_remap, &analyzer); + ffi::Map symbolic_var_map = InferSymbolicVarMap(relax_var_remap, analyzer); // for (const auto& [bind_param, bind_expr] : relax_var_remap) { // MatchSymbolicVar(bind_param, bind_expr, &symbolic_var_map, &analyzer); diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index a46b5c5b5546..d55dacc0ff26 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -125,7 +125,7 @@ ffi::TypedFunction(ffi::Map, ffi::Map(rhs_shapes[ind].size()), rhs_dim); // -2 for reduction and concat axes for (size_t i = 0; i < rhs_dim - 2; ++i) { - if (!ana.CanProve(rhs_shapes[indices[0]][i] == rhs_shapes[ind][i])) { + if (!ana->CanProve(rhs_shapes[indices[0]][i] == rhs_shapes[ind][i])) { return false; } } diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index d0089734ad24..9e4c11ee707a 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -41,7 +41,7 @@ namespace tirx { */ class SymbolicMatcher : ExprFunctor { public: - explicit SymbolicMatcher(arith::Analyzer* analyzer, ffi::Map* var_remap) + explicit SymbolicMatcher(arith::AnalyzerObj* analyzer, ffi::Map* var_remap) : analyzer_(analyzer), var_remap_(var_remap) {} void Match(const ffi::Array& params, const ffi::Array& args) { @@ -153,7 +153,7 @@ class SymbolicMatcher : ExprFunctor* var_remap_; PrimExpr must_prove_ = const_true(); }; @@ -1091,7 +1091,7 @@ class FusedTIRConstructor : public ExprVisitor { /*! \brief The map from symbolic var to its corresponding var in the fused function */ tirx::SymbolicMatcher symbolic_var_matcher = - tirx::SymbolicMatcher(&analyzer, &symbolic_var_remap); + tirx::SymbolicMatcher(analyzer.get(), &symbolic_var_remap); }; /*! \brief The IRModule */ diff --git a/src/relax/transform/remove_unused_parameters.cc b/src/relax/transform/remove_unused_parameters.cc index 9c1639d28c6e..daf4c6e2fd9c 100644 --- a/src/relax/transform/remove_unused_parameters.cc +++ b/src/relax/transform/remove_unused_parameters.cc @@ -128,7 +128,7 @@ std::optional AnalyzeCallee(Function func) { old_binding.Set(old_relax_params[i], old_args[i]); } arith::Analyzer analyzer; - auto tir_binding = InferSymbolicVarMap(old_binding, &analyzer); + auto tir_binding = InferSymbolicVarMap(old_binding, analyzer); for (const auto& tir_var : free_tir_vars) { new_args.push_back(PrimValue(tir_binding.at(tir_var))); diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc index b54544b00082..46ccdd82dfa4 100644 --- a/src/relax/transform/rewrite_dataflow_reshape.cc +++ b/src/relax/transform/rewrite_dataflow_reshape.cc @@ -144,7 +144,7 @@ class DataflowReshapeRewriter : public ExprMutator { }; auto inp_count = product(inp_sinfo->GetShape().value()); auto res_count = product(res_sinfo->GetShape().value()); - if (!arith::Analyzer().CanProveEqual(inp_count, res_count)) { + if (!arith::Analyzer()->CanProveEqual(inp_count, res_count)) { return false; } diff --git a/src/relax/transform/split_call_tir_by_pattern.cc b/src/relax/transform/split_call_tir_by_pattern.cc index 45c0e61a25f1..b73faa39007e 100644 --- a/src/relax/transform/split_call_tir_by_pattern.cc +++ b/src/relax/transform/split_call_tir_by_pattern.cc @@ -105,7 +105,7 @@ class ForMatcher : public TensorizeComparator { if (lhs->IsInstance() || lhs->IsInstance()) { ffi::Optional value = QueryEvaluatedSymbols(ffi::GetRef(op)); if (value.defined()) { - if (!analyzer_.CanProveEqual(lhs, value.value())) return false; + if (!analyzer_->CanProveEqual(lhs, value.value())) return false; } else { evaluated_symbols.back()[ffi::GetRef(op)] = lhs; } diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index b8b6ba30d25b..f0b27643b8e1 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -195,7 +195,7 @@ using Tokens = NestedMsg; */ class TokenAllocatorMixed { public: - explicit TokenAllocatorMixed(arith::Analyzer* analyzer) : analyzer_(analyzer) {} + explicit TokenAllocatorMixed(arith::AnalyzerObj* analyzer) : analyzer_(analyzer) {} /*! * \brief Request a storage token from the available token pool for a @@ -314,7 +314,7 @@ class TokenAllocatorMixed { }; /*! \brief The arithmetic analyzer. */ - arith::Analyzer* analyzer_; + arith::AnalyzerObj* analyzer_; /*! \brief A constant scale representing the token search range. */ const int match_range_{16}; /*! \brief The pool of available storage tokens for each storage scope and dtype. */ @@ -408,7 +408,7 @@ class StorageAllocatorBaseVisitor : public ExprVisitor { * \param ana The analyzer which contains the TIR var upper bounds. * \param dom_map The domain map of the TIR variables. */ -void SetTIRVarRangeConstraints(Function func, arith::Analyzer* ana, +void SetTIRVarRangeConstraints(Function func, arith::AnalyzerObj* ana, ffi::Map* dom_map) { // Use the attribute-annotated TIR var bounds as the TIR var values for // memory planning. @@ -468,7 +468,7 @@ void SetTIRVarRangeConstraints(Function func, arith::Analyzer* ana, * \return The upper-bounded shape. When a dimension's upper bound * cannot be determined, we keep the dimension unchanged. */ -ffi::Array GetUpperBoundShape(ffi::Array shape, arith::Analyzer* ana, +ffi::Array GetUpperBoundShape(ffi::Array shape, arith::AnalyzerObj* ana, const ffi::Map& dom_map) { // Use the upper bounds of TIR vars as their values. ffi::Array upper_bounded_shape; @@ -517,7 +517,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { * \return The mapping from each Expr to the token it uses. */ static std::unordered_map Initialize(const IRModule& mod, - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { StorageAllocatorInit initializer(mod, analyzer); for (auto it : mod->functions) { @@ -533,7 +533,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { private: using ExprVisitor::VisitExpr_; - explicit StorageAllocatorInit(const IRModule& ctx_mod, arith::Analyzer* analyzer) + explicit StorageAllocatorInit(const IRModule& ctx_mod, arith::AnalyzerObj* analyzer) : ctx_mod_(ctx_mod), analyzer_(analyzer) {} void VisitExpr_(const FunctionNode* func) final { @@ -724,7 +724,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { */ const IRModule& ctx_mod_; /*! \brief The arithmetic analyzer. */ - arith::Analyzer* analyzer_; + arith::AnalyzerObj* analyzer_; /*! \brief The domain map of dynamic TIR variables for analysis. */ ffi::Map dom_map_; /*! \brief The mapping from each token to the binding block where it is created. */ @@ -750,7 +750,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { class StorageAllocator : public StorageAllocatorBaseVisitor { public: explicit StorageAllocator(std::unordered_map token_map, - arith::Analyzer* analyzer) + arith::AnalyzerObj* analyzer) : allocator_(analyzer) { this->token_map_ = std::move(token_map); } @@ -902,7 +902,7 @@ class StorageAllocationRewriter : public ExprMutator { plan_dynamic_output_ = static_cast( func_->GetAttr(plan_dyn_attr_).value_or(IntImm(DataType::Int(32), 0))->value); if (plan_dynamic_output_) { - SetTIRVarRangeConstraints(ffi::GetRef(func_), &ana_, &dom_map_); + SetTIRVarRangeConstraints(ffi::GetRef(func_), ana_.get(), &dom_map_); } token2storage_var_.clear(); Function func = Downcast(this->VisitExpr_(func_)); @@ -966,7 +966,8 @@ class StorageAllocationRewriter : public ExprMutator { TVM_FFI_ICHECK_NOTNULL(sinfo); const auto* shape = sinfo->shape.as(); TVM_FFI_ICHECK_NOTNULL(shape); - ffi::Array upper_bounded_shape = GetUpperBoundShape(shape->values, &ana_, dom_map_); + ffi::Array upper_bounded_shape = + GetUpperBoundShape(shape->values, ana_.get(), dom_map_); if (!IsStaticShape(shape->values)) { TVM_FFI_ICHECK(!sinfo->IsUnknownDtype()); TVM_FFI_ICHECK_EQ(sinfo->dtype, Downcast(call->args[1])->value); @@ -1014,9 +1015,9 @@ IRModule StaticPlanBlockMemory(IRModule mod) { // Step 1. Initialize. std::unordered_map token_map = - StorageAllocatorInit::Initialize(mod, &ana); + StorageAllocatorInit::Initialize(mod, ana.get()); // Step 2. Collect the memory allocation info. - StorageAllocator allocator(std::move(token_map), &ana); + StorageAllocator allocator(std::move(token_map), ana.get()); allocator.Allocate(mod); // Step 3. Rewrite the function. StorageAllocationRewriter rewriter(std::move(mod), // diff --git a/src/relax/utils.cc b/src/relax/utils.cc index 81e810275105..2155824bda38 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -81,7 +81,7 @@ class ExprBinder : public ExprMutator { auto new_expr = tirx::Substitute(expr, symbolic_var_map_); if (!expr.same_as(new_expr)) { arith::Analyzer analyzer; - new_expr = analyzer.Simplify(new_expr); + new_expr = analyzer->Simplify(new_expr); } return new_expr; } @@ -109,7 +109,9 @@ StructInfo Bind(const StructInfo& sinfo, } tvm::ffi::Map InferSymbolicVarMap( - const tvm::ffi::Map& relax_var_remap, arith::Analyzer* analyzer) { + const tvm::ffi::Map& relax_var_remap, + const arith::Analyzer& analyzer) { + (void)analyzer; tvm::ffi::Map tir_var_remap; auto bind_from_prim_expr = [&tir_var_remap](const PrimExpr& var_shape, diff --git a/src/s_tir/analysis/estimate_flops.cc b/src/s_tir/analysis/estimate_flops.cc index 9f3e77a2e88e..d77e715db1b6 100644 --- a/src/s_tir/analysis/estimate_flops.cc +++ b/src/s_tir/analysis/estimate_flops.cc @@ -119,7 +119,7 @@ class FlopEstimator : private ExprFunctor, TResult VisitExpr_(const GENode* op) override { return TResult(); } int64_t GetLoopExtent(const ForNode* node, const arith::Analyzer& ana) { - int64_t bound = ana.const_int_bound(node->extent)->max_value; + int64_t bound = ana->const_int_bound(node->extent)->max_value; if (bound == arith::ConstIntBound::kPosInf) { return 1; // Analyzer could not determine a valid bound, use 1 instead. } else { @@ -158,7 +158,7 @@ class FlopEstimator : private ExprFunctor, return result; } TResult VisitStmt_(const ForNode* loop) override { - ana.Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + ana->Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); const auto int_imm = GetLoopExtent(loop, ana); TResult result = VisitStmt(loop->body); result *= int_imm; diff --git a/src/s_tir/analysis/identify_memcpy.cc b/src/s_tir/analysis/identify_memcpy.cc index 11cdc2487548..e008f7e7ebc3 100644 --- a/src/s_tir/analysis/identify_memcpy.cc +++ b/src/s_tir/analysis/identify_memcpy.cc @@ -44,7 +44,7 @@ namespace s_tir { using namespace tvm::tirx; std::variant IdentifyMemCpyImpl(const For& loop, - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { ffi::Map loop_intervals; ffi::Map loop_ranges; PrimExpr total_loop_iterations = 1; @@ -106,8 +106,9 @@ std::variant IdentifyMemCpyImpl(const For& loop, // for i in T.serial(16): // B[i] = A[T.abs(i-8)] + arith::Analyzer analyzer_ref = ffi::GetRef(analyzer); auto src_iter_map = arith::DetectIterMap({src_index}, loop_ranges, const_true(), - arith::IterMapLevel::Bijective, analyzer); + arith::IterMapLevel::Bijective, analyzer_ref); if (src_iter_map->errors.size()) { return static_cast(std::stringstream() << "arith::DetectIterMap(src) returned " @@ -117,7 +118,7 @@ std::variant IdentifyMemCpyImpl(const For& loop, .str(); } auto dst_iter_map = arith::DetectIterMap({dst_index}, loop_ranges, const_true(), - arith::IterMapLevel::Bijective, analyzer); + arith::IterMapLevel::Bijective, analyzer_ref); if (dst_iter_map->errors.size()) { return static_cast(std::stringstream() << "arith::DetectIterMap(dst) returned " @@ -276,8 +277,8 @@ std::variant IdentifyMemCpyImpl(const For& loop, return MemCpyDetails{src_region, dst_region}; } -std::optional IdentifyMemCpy(const For& loop, arith::Analyzer* analyzer) { - auto result = IdentifyMemCpyImpl(loop, analyzer); +std::optional IdentifyMemCpy(const For& loop, const arith::Analyzer& analyzer) { + auto result = IdentifyMemCpyImpl(loop, analyzer.get()); if (auto* ptr = std::get_if(&result)) { return *ptr; } else { @@ -299,7 +300,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { using IRVisitorWithAnalyzer::VisitStmt_; void VisitStmt_(const ForNode* op) override { For loop = ffi::GetRef(op); - auto result = IdentifyMemCpyImpl(loop, &(Visitor::analyzer_)); + auto result = IdentifyMemCpyImpl(loop, Visitor::analyzer_.get()); if (auto* ptr = std::get_if(&result)) { output->push_back(ffi::Array{ptr->source, ptr->dest}); } else if (auto* ptr = std::get_if(&result)) { diff --git a/src/s_tir/analysis/oob_checker.cc b/src/s_tir/analysis/oob_checker.cc index 300f61327b1a..a37c8387731e 100644 --- a/src/s_tir/analysis/oob_checker.cc +++ b/src/s_tir/analysis/oob_checker.cc @@ -89,8 +89,8 @@ class OOBCheckerVisitor final : public arith::IRVisitorWithAnalyzer { template void CheckBounds(const T* node, size_t i) { - auto ind_bounds = analyzer_.int_set(node->indices[i]); - auto shape_bounds = analyzer_.int_set(node->buffer->shape[i]); + auto ind_bounds = analyzer_->int_set(node->indices[i]); + auto shape_bounds = analyzer_->int_set(node->buffer->shape[i]); // We would expect that // `analyzer_.CanProve(node->indices[i] < 0 || node->indices[i] >= node->buffer->shape[i])` // would be the way to check if any out of bounds access occurs here, but `CanProve` checks if @@ -102,8 +102,8 @@ class OOBCheckerVisitor final : public arith::IRVisitorWithAnalyzer { // has the problem that some valid access patterns maybe be valid but not provably valid. We // prefer that this analysis is conservative and only shows errors that are provable. This leads // us to the following check: are the bounds of the index outside the bounds of the shape. - if (analyzer_.CanProve(ind_bounds.max() >= shape_bounds.min()) || - analyzer_.CanProve(ind_bounds.min() < 0)) { + if (analyzer_->CanProve(ind_bounds.max() >= shape_bounds.min()) || + analyzer_->CanProve(ind_bounds.min() < 0)) { errors.push_back({node->buffer, i, node->indices[i], ind_bounds, shape_bounds}); } } diff --git a/src/s_tir/analysis/sblock_access_region_detector.cc b/src/s_tir/analysis/sblock_access_region_detector.cc index 918b9d2815d8..0eddf22d8506 100644 --- a/src/s_tir/analysis/sblock_access_region_detector.cc +++ b/src/s_tir/analysis/sblock_access_region_detector.cc @@ -346,7 +346,7 @@ ffi::Array BlockReadWriteDetector::CollectRegions( TVM_FFI_ICHECK_EQ(buffers[i]->shape.size(), regions[i].size()); for (size_t j = 0; j < regions[i].size(); j++) { const tvm::arith::IntSet& range = regions[i][j]; - if (range.CanProveSinglePoint(&ana_)) { + if (range.CanProveSinglePoint(ana_)) { PrimExpr min = range.min(); region.push_back(Range::FromMinExtent(min, make_const(min.dtype(), 1))); } else { diff --git a/src/s_tir/backend/adreno/inject_texture_alloc.cc b/src/s_tir/backend/adreno/inject_texture_alloc.cc index f52a6d7148c6..ef0fe72acd28 100644 --- a/src/s_tir/backend/adreno/inject_texture_alloc.cc +++ b/src/s_tir/backend/adreno/inject_texture_alloc.cc @@ -46,7 +46,7 @@ class TextureAllocInjector : public arith::IRMutatorWithAnalyzer { public: static PrimFunc Inject(PrimFunc func) { arith::Analyzer ana; - auto pass = TextureAllocInjector(&ana); + auto pass = TextureAllocInjector(ana.get()); auto writer = func.CopyOnWrite(); pass.MarkBufferMapShapes(func); writer->body = pass.VisitStmt(func->body); @@ -59,7 +59,7 @@ class TextureAllocInjector : public arith::IRMutatorWithAnalyzer { using IRMutatorWithAnalyzer::VisitStmt; using IRMutatorWithAnalyzer::VisitStmt_; - explicit TextureAllocInjector(arith::Analyzer* ana) : IRMutatorWithAnalyzer(ana) {} + explicit TextureAllocInjector(arith::AnalyzerObj* ana) : IRMutatorWithAnalyzer(ana) {} Stmt VisitStmt_(const AllocBufferNode* op) final { Stmt stmt = StmtExprMutator::VisitStmt_(op); diff --git a/src/s_tir/data_layout.cc b/src/s_tir/data_layout.cc index 34682315c7e8..787386c8ccb9 100644 --- a/src/s_tir/data_layout.cc +++ b/src/s_tir/data_layout.cc @@ -417,7 +417,7 @@ inline bool GetStoreRule(ffi::Array* index_rule, ffi::Array* factor = factor * dst_unpacked_axes[k]->dom->extent.as().value(); } } - ana.Simplify(factor); + ana->Simplify(factor); index_rule->push_back(factor); shape_rule->push_back(factor); } @@ -450,7 +450,7 @@ inline ffi::Array TransformIndex(const ffi::Array& src_index bind_map[src_axis[i]->var.get()] = src_index[i]; } for (PrimExpr rule : transform_rule) { - result.push_back(ana.Simplify(tirx::Substitute(rule, bind_map))); + result.push_back(ana->Simplify(tirx::Substitute(rule, bind_map))); } return result; } @@ -517,7 +517,7 @@ inline ffi::Array TransformShape(const ffi::Array& src_shape if (layout.size() != 1 || !SLayoutAxis::Get(layout[0]).IsPrimal()) { result.push_back(axis->dom->extent); } else { - result.push_back(ana.Simplify(tirx::Substitute(rule, bind_map))); + result.push_back(ana->Simplify(tirx::Substitute(rule, bind_map))); } } diff --git a/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc b/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc index c4fad1e7fb37..b567ffa4eb1f 100644 --- a/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc @@ -62,7 +62,7 @@ namespace utils { * \param analyzer The analyzer * \return The shape of the buffer */ -std::vector GetBufferShape(const Buffer& buffer, arith::Analyzer* analyzer) { +std::vector GetBufferShape(const Buffer& buffer, arith::AnalyzerObj* analyzer) { int ndim = buffer->shape.size(); std::vector result; result.reserve(ndim); @@ -121,7 +121,7 @@ int64_t FirstLoopExtent(const ForVec& loops, int64_t default_value) { * \return The relaxed and unioned region */ IntVec RelaxAndUnion(const std::vector& multi_indices, int64_t* numel, - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { *numel = 1; if (multi_indices.empty()) { return {}; @@ -737,7 +737,7 @@ struct Feature { static void Pad(std::vector* v) { v->insert(v->end(), 18, 0.0); } - void SetStride(const LoopNest& loop_nest, arith::Analyzer* analyzer); + void SetStride(const LoopNest& loop_nest, arith::AnalyzerObj* analyzer); void SetReuse(const LoopNest& loop_nest, // int64_t top_loop_touch_bytes, // @@ -766,14 +766,14 @@ struct Feature { explicit Feature(const BufferStoreNode* store, const LoopNest& loop_nest, int64_t cache_line_bytes, IntVec* for_touched_bytes, - ForBufferMap* buffer_touched_under_loop, arith::Analyzer* analyzer); + ForBufferMap* buffer_touched_under_loop, arith::AnalyzerObj* analyzer); void Init(const BufferStoreNode* store, int n_loops); void SetRegion(const LoopNest& loop_nest, // IntVec* for_touched_bytes, // ForBufferMap* buffer_touched_under_loop, // - arith::Analyzer* analyzer); + arith::AnalyzerObj* analyzer); std::vector sub_features; }; @@ -820,7 +820,7 @@ void Feature::Init(const BufferStoreNode* store, int n_loops) { void Feature::SetRegion(const LoopNest& loop_nest, IntVec* for_touched_bytes, ForBufferMap* buffer_touched_under_loop, - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { int n_loops = loop_nest.loops.size(); const std::vector& loops = loop_nest.loops; // Step 1. Initialize and bind all the loop variables to a constant @@ -858,7 +858,7 @@ void Feature::SetRegion(const LoopNest& loop_nest, IntVec* for_touched_bytes, } } -void Feature::SubFeature::SetStride(const LoopNest& loop_nest, arith::Analyzer* analyzer) { +void Feature::SubFeature::SetStride(const LoopNest& loop_nest, arith::AnalyzerObj* analyzer) { int n_loops = loop_nest.loops.size(); const std::vector& loops = loop_nest.loops; // For each buffer, we find the loop stride on it @@ -1009,7 +1009,7 @@ void Feature::SubFeature::SetFeature(const LoopNest& loop_nest, int64_t cache_li Feature::Feature(const BufferStoreNode* store, const LoopNest& loop_nest, int64_t cache_line_bytes, IntVec* for_touched_bytes, ForBufferMap* buffer_touched_under_loop, - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { int n_loops = loop_nest.loops.size(); // Step 0. Initialize data structures this->Init(store, n_loops); @@ -1155,7 +1155,7 @@ struct Feature { Feature() = default; - explicit Feature(const LoopNest& loop_nest, const Buffer& buffer, arith::Analyzer* analyzer) { + explicit Feature(const LoopNest& loop_nest, const Buffer& buffer, arith::AnalyzerObj* analyzer) { std::vector shape = utils::GetBufferShape(buffer, analyzer); int64_t numel = 1; for (int64_t x : shape) { @@ -1324,7 +1324,7 @@ class PerStoreFeatureCollector : private StmtVisitor { feature.group1 = std::make_unique(store, loop_nest_, is_gpu_); feature.group2 = std::make_unique(store, loop_nest_, cache_line_bytes_, &for_touched_bytes_, - &buffer_touched_under_loop_, &analyzer_); + &buffer_touched_under_loop_, analyzer_.get()); feature.group3 = std::make_unique(arith_intensity_curve_num_samples_, loop_nest_, for_touched_bytes_, feature.group1->arith_ops); @@ -1340,7 +1340,7 @@ class PerStoreFeatureCollector : private StmtVisitor { void HandleBufferAlloc(const Buffer& buffer) { Feature& feature = buffer_features_[buffer.get()]; - feature.group4 = std::make_unique(loop_nest_, buffer, &analyzer_); + feature.group4 = std::make_unique(loop_nest_, buffer, analyzer_.get()); } explicit PerStoreFeatureCollector(bool is_gpu, int64_t cache_line_bytes, diff --git a/src/s_tir/meta_schedule/postproc/disallow_async_strided_mem_copy.cc b/src/s_tir/meta_schedule/postproc/disallow_async_strided_mem_copy.cc index 6e1f195e75b3..cfa7393203a0 100644 --- a/src/s_tir/meta_schedule/postproc/disallow_async_strided_mem_copy.cc +++ b/src/s_tir/meta_schedule/postproc/disallow_async_strided_mem_copy.cc @@ -92,7 +92,7 @@ struct AsyncStridedMemCopyFinder : private StmtExprVisitor { // Use DetectIterMap to detect whether store index is non-contiguous. arith::Analyzer analyzer; auto store_iter_map = DetectIterMap(store_index, input_iters, 1, - arith::IterMapLevel::Surjective, &analyzer, false); + arith::IterMapLevel::Surjective, analyzer, false); if (!store_iter_map->errors.empty()) { found_ = true; } @@ -102,7 +102,7 @@ struct AsyncStridedMemCopyFinder : private StmtExprVisitor { // Use DetectIterMap to detect whether load index is non-contiguous. auto load_iter_map = DetectIterMap(load_index, input_iters, 1, - arith::IterMapLevel::Surjective, &analyzer, false); + arith::IterMapLevel::Surjective, analyzer, false); if (!load_iter_map->errors.empty()) { found_ = true; } diff --git a/src/s_tir/meta_schedule/postproc/rewrite_layout.cc b/src/s_tir/meta_schedule/postproc/rewrite_layout.cc index d53e53969ad0..cb0504be0c4d 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_layout.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_layout.cc @@ -74,7 +74,7 @@ class BufferReadPosCollector : public StmtExprVisitor { /*indices=*/subst_indices, // /*loops=*/loop_stack_, // /*predicate=*/cur_realize_->predicate, // - /*analyzer=*/&analyzer_); + /*analyzer=*/analyzer_.get()); int buffer_index = GetReadBufferIndex(cur_realize_->block, buffer); TVM_FFI_ICHECK(buffer_index != -1); buffer_loc_ = std::make_pair(cur_realize_->block, buffer_index); diff --git a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index b77355ee3bb2..d8cb2f853ea2 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -222,7 +222,7 @@ void AdjustParallelVectorize(const Schedule& sch, const SBlockRV& block_rv, const auto* var = loop_sref->StmtAs(); arith::Analyzer analyzer; for (int i = access->region.size() - 1; i >= 0; i--) { - PrimExpr idx = analyzer.Simplify(Substitute(access->region[i]->min, binding_map)); + PrimExpr idx = analyzer->Simplify(Substitute(access->region[i]->min, binding_map)); int64_t coef = StrideExtractor::Extract(idx, var->loop_var); if (coef != 0) { stride = coef * buffer_stride; diff --git a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc index 1dee2fe1d007..d9f49538b268 100644 --- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc +++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc @@ -95,7 +95,7 @@ MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch, SBlockRV block_rv const size_t innermost_axis = block_node->writes[0]->region.size() - 1; const PrimExpr innermost_iter_value = block_realize->iter_values[innermost_axis]; - if (!arith::Analyzer().CanProve(loop->loop_var == innermost_iter_value)) { + if (!arith::Analyzer()->CanProve(loop->loop_var == innermost_iter_value)) { // If this is not the innermost spatial loop, split the loop in the normal way. return MultiLevelTilingNode::SplitLoop(sch, block_rv, loop_rv, n_tiles); } else { diff --git a/src/s_tir/schedule/analysis.h b/src/s_tir/schedule/analysis.h index 67df49ac75d3..27454e5e6434 100644 --- a/src/s_tir/schedule/analysis.h +++ b/src/s_tir/schedule/analysis.h @@ -81,7 +81,7 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref); * \param analyzer The analyzer to be bound */ void AddShapeVarBounds(const ScheduleState& state, const StmtSRefNode* sref, - arith::Analyzer* analyzer); + arith::AnalyzerObj* analyzer); /******** Scope ********/ /*! @@ -232,7 +232,7 @@ bool IsWriteCache(const StmtSRef& block_sref); * \return A boolean flag indicating if the binding is affine */ bool IsAffineBinding(const SBlockRealize& realize, const ffi::Map& loop_var_ranges, - arith::Analyzer* analyzer); + arith::AnalyzerObj* analyzer); /*! * \brief Check whether a block has an affine binding using the cached flag, and throw an exception @@ -298,7 +298,7 @@ bool GetVarsTouchedByBlockIters(const SBlockRealize& block_realize, * \throw ScheduleError If the loop doesn't starts with zero. */ void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref, - arith::Analyzer* analyzer); + arith::AnalyzerObj* analyzer); /*! * \brief Check whether a block has a trivial binding, i.e. each block var is bound to a outer loop, @@ -602,7 +602,7 @@ bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref, */ ffi::Optional SuggestIndexMap(const Buffer& buffer, const ffi::Array& indices, const ffi::Array& loops, const PrimExpr& predicate, - arith::Analyzer* analyzer); + arith::AnalyzerObj* analyzer); /*! * \brief Checks if the given AST contains the specific operators @@ -706,7 +706,7 @@ ffi::Array AnalyzeRegionUpperBound(const BufferRegion& region, const PrimExpr& predicate, const StmtSRef& dom_low_inclusive, const StmtSRef& dom_high_exclusive, - arith::Analyzer* analyzer); + arith::AnalyzerObj* analyzer); /*! * \brief Analyze the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive) @@ -722,7 +722,7 @@ ffi::Array AnalyzeRegionLowerBound(const BufferRegion& region, const PrimExpr& predicate, const StmtSRef& dom_low_inclusive, const StmtSRef& dom_high_exclusive, - arith::Analyzer* analyzer); + arith::AnalyzerObj* analyzer); /*! * \brief Simplify non-trivial expressions @@ -734,7 +734,7 @@ ffi::Array AnalyzeRegionLowerBound(const BufferRegion& region, * simplified to constant values for further scheduling and analysis because simplifing away the * block iters may result in loss of information for further analysis. */ -PrimExpr SimplifyNonTrivialExpr(const PrimExpr& expr, arith::Analyzer* analyzer); +PrimExpr SimplifyNonTrivialExpr(const PrimExpr& expr, arith::AnalyzerObj* analyzer); /*! \brief Necessary information used for tensorization */ class TensorizeInfoNode : public ffi::Object { diff --git a/src/s_tir/schedule/analysis/analysis.cc b/src/s_tir/schedule/analysis/analysis.cc index 3446d1fa639f..52e5cfe287d1 100644 --- a/src/s_tir/schedule/analysis/analysis.cc +++ b/src/s_tir/schedule/analysis/analysis.cc @@ -555,7 +555,7 @@ bool IsWriteCache(const StmtSRef& block_sref) { /******** Binding ********/ bool IsAffineBinding(const SBlockRealize& realize, const ffi::Map& loop_var_ranges, - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { if (loop_var_ranges.empty()) { return true; } @@ -564,7 +564,7 @@ bool IsAffineBinding(const SBlockRealize& realize, const ffi::Map& l /*input_iters=*/loop_var_ranges, /*predicate=*/realize->predicate, /*check_level=*/arith::IterMapLevel::Surjective, - /*analyzer=*/analyzer, + /*analyzer=*/ffi::GetRef(analyzer), /*simplify_trivial_iterators=*/false); if (res->indices.empty()) { return false; @@ -626,7 +626,7 @@ void CheckPartialAffineBinding(const ScheduleState& self, SBlock block, arith::Analyzer analyzer; ffi::Map dom_map = LoopDomainOfSRefTreePath(ffi::GetRef(block_sref->parent), high_exclusive); - if (IsAffineBinding(GetSBlockRealize(self, block_sref), dom_map, &analyzer)) { + if (IsAffineBinding(GetSBlockRealize(self, block_sref), dom_map, analyzer.get())) { return; } } @@ -746,7 +746,7 @@ bool GetVarsTouchedByBlockIters(const SBlockRealize& block_realize, /******** Loop properties ********/ void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref, - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { class LoopNotStartWithZeroError : public ScheduleError { public: explicit LoopNotStartWithZeroError(IRModule mod, For loop) @@ -1304,7 +1304,7 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref) { } void AddShapeVarBounds(const ScheduleState& state, const StmtSRefNode* sref, - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { while (sref->parent != nullptr) { sref = sref->parent; } @@ -1698,7 +1698,7 @@ bool NeedsRFactorOrCrossThreadReduction(const s_tir::ScheduleState& self, // } } -PrimExpr SimplifyNonTrivialExpr(const PrimExpr& expr, arith::Analyzer* analyzer) { +PrimExpr SimplifyNonTrivialExpr(const PrimExpr& expr, arith::AnalyzerObj* analyzer) { auto simplified = analyzer->Simplify(expr); if (simplified->IsInstance()) { return expr; @@ -1725,7 +1725,7 @@ struct TensorIntrinDescInfo { * \param desc_func The description PrimFunc * \return The auxilary information */ -TensorIntrinDescInfo ExtractTensorIntrinDescInfo(arith::Analyzer* analyzer, +TensorIntrinDescInfo ExtractTensorIntrinDescInfo(arith::AnalyzerObj* analyzer, const PrimFunc& desc_func) { TensorIntrinDescInfo info; const auto* desc_scope_realize = desc_func->body.as(); @@ -1761,7 +1761,7 @@ ffi::Optional GetTensorizeLoopMapping(const s_tir::ScheduleState& arith::Analyzer analyzer; const tirx::SBlockRealize& block = GetSBlockRealize(self, block_sref); // Step 1. Analyze desc_func, extract its block, loops and loop vars - TensorIntrinDescInfo desc_info = ExtractTensorIntrinDescInfo(&analyzer, desc_func); + TensorIntrinDescInfo desc_info = ExtractTensorIntrinDescInfo(analyzer.get(), desc_func); // Step 2. Collect loops from block_sref const tirx::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false); TVM_SREF_TO_SBLOCK(scope_sref); @@ -1775,7 +1775,7 @@ ffi::Optional GetTensorizeLoopMapping(const s_tir::ScheduleState& } block_loops.push_back(loop); block_loop_vars.insert(loop->loop_var.get()); - if (!analyzer.CanProve(loop->min == 0)) { + if (!analyzer->CanProve(loop->min == 0)) { return std::nullopt; } } @@ -1826,7 +1826,7 @@ ffi::Optional GetTensorizeLoopMapping(const s_tir::ScheduleState& IterVarType iter_type_desc = iter_types_desc[i_desc]; for (int i = 0, n = desc_loops.size(); i < n; ++i) { // Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars - PrimExpr residual = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var); + PrimExpr residual = analyzer->Simplify(desc_bind - desc_loops[i]->loop_var); if (!UsesVar(residual, [&desc_loop_vars](const VarNode* var) { return desc_loop_vars.count(var); })) { desc_loop = desc_loops[i]; @@ -1861,7 +1861,7 @@ ffi::Optional GetTensorizeLoopMapping(const s_tir::ScheduleState& // Skip i-th loop if it has already been mapped if (ret->loop_map.find(block_loop_sref) != ret->loop_map.end()) continue; - PrimExpr residual = analyzer.Simplify(block_bind - block_loops[i]->loop_var); + PrimExpr residual = analyzer->Simplify(block_bind - block_loops[i]->loop_var); if (UsesVar(residual, [&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); })) { continue; @@ -1930,7 +1930,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { class AutoTensorizeMappingProposer { public: static ffi::Array ProposeMappings(const AutoTensorizeComparator* extractor, - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { AutoTensorizeMappingProposer proposer(extractor, analyzer); proposer.CollectFeasibleSet(); return proposer.ProposeAllFuseMapping(); @@ -1938,7 +1938,7 @@ class AutoTensorizeMappingProposer { private: explicit AutoTensorizeMappingProposer(const AutoTensorizeComparator* extractor, - arith::Analyzer* analyzer) + arith::AnalyzerObj* analyzer) : extractor_(extractor), analyzer_(analyzer) {} using VarSet = std::unordered_set; @@ -2102,7 +2102,7 @@ class AutoTensorizeMappingProposer { // tensor intrin. const AutoTensorizeComparator* extractor_; // The arithmetic analyzer. - arith::Analyzer* analyzer_; + arith::AnalyzerObj* analyzer_; /*! \brief Potential mappings on RHS for each variable on LHS */ std::unordered_map lhs_feasible_vars_; }; @@ -2115,7 +2115,7 @@ bool CheckAutoTensorizeApplicable(const ScheduleState& state, const tirx::StmtSR // Ignore the scope of buffers when comparing, since we can do cache_read/write const SBlockRealize& block = GetSBlockRealize(state, block_sref); arith::Analyzer analyzer; - auto desc_info = ExtractTensorIntrinDescInfo(&analyzer, desc_func); + auto desc_info = ExtractTensorIntrinDescInfo(analyzer.get(), desc_func); return extractor->VisitStmt(block->block, desc_info.desc_block->block); } @@ -2135,7 +2135,7 @@ ffi::Optional GetAutoTensorizeMappingInfo( } arith::Analyzer analyzer; ffi::Array mappings = - AutoTensorizeMappingProposer::ProposeMappings(&extractor, &analyzer); + AutoTensorizeMappingProposer::ProposeMappings(&extractor, analyzer.get()); if (mappings.empty()) { return std::nullopt; } diff --git a/src/s_tir/schedule/analysis/layout.cc b/src/s_tir/schedule/analysis/layout.cc index 035faee48436..d99f99bd847b 100644 --- a/src/s_tir/schedule/analysis/layout.cc +++ b/src/s_tir/schedule/analysis/layout.cc @@ -80,9 +80,10 @@ class SplitExprCollector { const ffi::Map& input_iters, // const PrimExpr& predicate, // arith::IterMapLevel check_level, // - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { + arith::Analyzer analyzer_ref = ffi::GetRef(analyzer); arith::IterMapResult res = arith::DetectIterMap({analyzer->Simplify(index)}, input_iters, - predicate, check_level, analyzer); + predicate, check_level, analyzer_ref); const auto& iter_sum_exprs = res->indices; if (iter_sum_exprs.empty()) { return {}; @@ -130,7 +131,7 @@ class SplitExprCollector { ffi::Optional SuggestIndexMap(const Buffer& buffer, const ffi::Array& indices, const ffi::Array& loops, const PrimExpr& predicate, - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { int ndim = buffer->shape.size(); int n_loops = loops.size(); // Step 1. Collect the domains and indices of loop variables @@ -250,7 +251,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { "s_tir.schedule.SuggestIndexMap", [](Buffer buffer, ffi::Array indices, ffi::Array loops, PrimExpr predicate) { arith::Analyzer analyzer; - return SuggestIndexMap(buffer, indices, loops, predicate, &analyzer); + return SuggestIndexMap(buffer, indices, loops, predicate, analyzer.get()); }); } diff --git a/src/s_tir/schedule/analysis/reducer.cc b/src/s_tir/schedule/analysis/reducer.cc index 74e34aaef634..d6bb5c903492 100644 --- a/src/s_tir/schedule/analysis/reducer.cc +++ b/src/s_tir/schedule/analysis/reducer.cc @@ -490,11 +490,11 @@ std::pair, ffi::Array> GetInitValuesAndUpdates ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/12); } for (int d = 0; d < n_dim; ++d) { - if (!ana.CanProveEqual(updates[i]->buffer->shape[d], expected_shape[d])) { + if (!ana->CanProveEqual(updates[i]->buffer->shape[d], expected_shape[d])) { ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/11); } - if (!ana.CanProveEqual(inits[i]->indices[d], expected_indices[d]) || - !ana.CanProveEqual(updates[i]->indices[d], expected_indices[d])) { + if (!ana->CanProveEqual(inits[i]->indices[d], expected_indices[d]) || + !ana->CanProveEqual(updates[i]->indices[d], expected_indices[d])) { ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/12); } } diff --git a/src/s_tir/schedule/concrete_schedule.cc b/src/s_tir/schedule/concrete_schedule.cc index e69b5b2f47e7..8a9d0aab370c 100644 --- a/src/s_tir/schedule/concrete_schedule.cc +++ b/src/s_tir/schedule/concrete_schedule.cc @@ -33,7 +33,7 @@ Schedule Schedule::Concrete(IRModule mod, LinearCongruentialEngine::TRandState s n->state_ = ScheduleState(mod, debug_mask, enable_check); n->error_render_level_ = error_render_level; n->symbol_table_ = {}; - n->analyzer_ = std::make_unique(); + n->analyzer_ = arith::Analyzer(); n->Seed(seed); GlobalVar gv; if (FindEntryFunc(mod, &gv) != nullptr) { @@ -201,7 +201,7 @@ Schedule ConcreteScheduleNode::Copy() { n->func_working_on_ = this->func_working_on_; n->error_render_level_ = this->error_render_level_; ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_); - n->analyzer_ = std::make_unique(); // new analyzer needed because it is stateful + n->analyzer_ = arith::Analyzer(); // new analyzer needed because it is stateful n->rand_state_ = ForkSeed(); return Schedule(std::move(n)); } diff --git a/src/s_tir/schedule/concrete_schedule.h b/src/s_tir/schedule/concrete_schedule.h index 848965208cc3..4fd6f80b4e8c 100644 --- a/src/s_tir/schedule/concrete_schedule.h +++ b/src/s_tir/schedule/concrete_schedule.h @@ -48,7 +48,7 @@ class ConcreteScheduleNode : public ScheduleNode { /*! \brief A symbol table that maps random variables to concrete StmtSRef/Integers */ TSymbolTable symbol_table_; /*! \brief A persistent stateless arithmetic analyzer. */ - std::unique_ptr analyzer_; + arith::Analyzer analyzer_; /*! \brief The value of random state for sampling. */ LinearCongruentialEngine::TRandState rand_state_; diff --git a/src/s_tir/schedule/ir_comparator.cc b/src/s_tir/schedule/ir_comparator.cc index 1bb66a238104..5f83d276c720 100644 --- a/src/s_tir/schedule/ir_comparator.cc +++ b/src/s_tir/schedule/ir_comparator.cc @@ -96,7 +96,7 @@ bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) { bool equal = n.same_as(other) || ((n->type_index() == other->type_index()) && n.dtype().code() == other.dtype().code() && ExprComparator::VisitExpr(n, other)) || - (ContainsVscaleCall(n) && analyzer_.CanProveEqual(n, other)); + (ContainsVscaleCall(n) && analyzer_->CanProveEqual(n, other)); if (!equal && assert_mode_) { std::ostringstream os; @@ -221,7 +221,7 @@ bool TensorizeComparator::VisitStmt_(const SBlockRealizeNode* op, const Stmt& ot bool TensorizeComparator::VisitStmt_(const SBlockNode* op, const Stmt& other) { const auto* rhs = other.as(); for (const IterVar& iter : op->iter_vars) { - lhs_analyzer_.Bind(iter->var, iter->dom); + lhs_analyzer_->Bind(iter->var, iter->dom); } // Check block equality. // All iter vars and buffer regions including the order should match. @@ -363,7 +363,7 @@ bool TensorizeComparator::DefEqual(const Var& lhs, const Var& rhs) { equal_map_[lhs] = rhs; // Cast if necessary. This allows the workload and the tensor intrin to have different dtypes in // the indices. - analyzer_.Bind(lhs, cast(lhs.dtype(), rhs)); + analyzer_->Bind(lhs, cast(lhs.dtype(), rhs)); return true; } @@ -503,7 +503,7 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf // save base index indices_base.emplace_back(lhs->region[i + offset]->min); // check extent match - if (!analyzer_.CanProveEqual(lhs->region[i + offset]->extent, rhs->region[i]->extent)) { + if (!analyzer_->CanProveEqual(lhs->region[i + offset]->extent, rhs->region[i]->extent)) { if (assert_mode_) { std::ostringstream os; os << "CompareBufferRegion buffer extent mismatch: lhs->region[i + offset]=" @@ -529,7 +529,7 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf } return false; } - if (!lhs_analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) { + if (!lhs_analyzer_->CanProveEqual(indices_base[i], lhs->region[i]->min)) { if (assert_mode_) { std::ostringstream os; os << "Buffer base index consistency check failed due to unequal index base: " @@ -542,7 +542,7 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf } for (size_t i = 0; i < rhs->region.size(); i++) { // check extent match - if (!analyzer_.CanProveEqual(lhs->region[i + offset]->extent, rhs->region[i]->extent)) { + if (!analyzer_->CanProveEqual(lhs->region[i + offset]->extent, rhs->region[i]->extent)) { if (assert_mode_) { std::ostringstream os; os << "CompareBufferRegion buffer region extent mismatch. lhs->region[i + offset]=" @@ -552,8 +552,8 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf return false; } PrimExpr normalized_lhs_min = - lhs_analyzer_.Simplify((lhs->region[i + offset]->min - indices_base[i + offset])); - if (!analyzer_.CanProveEqual(normalized_lhs_min, rhs->region[i]->min)) { + lhs_analyzer_->Simplify((lhs->region[i + offset]->min - indices_base[i + offset])); + if (!analyzer_->CanProveEqual(normalized_lhs_min, rhs->region[i]->min)) { if (assert_mode_) { std::ostringstream os; os << "CompareBufferRegion buffer region min mismatch. lhs->region[i + offset]=" @@ -588,7 +588,7 @@ bool TensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { TVM_FFI_ICHECK_EQ(indices_base.size(), rhs->indices.size() + offset); for (size_t i = 0; i < rhs->indices.size(); i++) { PrimExpr normalized_lhs_index = lhs->indices[i + offset] - indices_base[i + offset]; - if (!analyzer_.CanProveEqual(normalized_lhs_index, rhs->indices[i])) { + if (!analyzer_->CanProveEqual(normalized_lhs_index, rhs->indices[i])) { if (assert_mode_) { std::ostringstream os; os << "CompareBufferAccess buffer indices mismatch. lhs->indices[i + offset]=" @@ -664,7 +664,7 @@ bool AutoTensorizeComparator::VisitStmt_(const SBlockNode* op, const Stmt& other } else { auto collect_iter = [&](const SBlockNode* op, std::vector& iters) -> bool { for (const auto& iter : op->iter_vars) { - analyzer_.Bind(iter->var, iter->dom); + analyzer_->Bind(iter->var, iter->dom); if (iter->iter_type == IterVarType::kDataPar || iter->iter_type == IterVarType::kCommReduce) { iters.push_back(iter); @@ -722,7 +722,7 @@ bool AutoTensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { } std::vector lhs_indices; for (const PrimExpr& index : lhs->indices) { - lhs_indices.push_back(SimplifyNonTrivialExpr(index, &analyzer_)); + lhs_indices.push_back(SimplifyNonTrivialExpr(index, analyzer_.get())); } auto is_scalar_access = [](const ffi::Array& indices, PrimExpr index) { @@ -749,7 +749,7 @@ bool AutoTensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) { return false; } for (size_t i = 0; i < indices.size(); ++i) { - if (!analyzer_.CanProveEqual(indices[i], old_indices[i])) { + if (!analyzer_->CanProveEqual(indices[i], old_indices[i])) { return false; } } diff --git a/src/s_tir/schedule/primitive/annotate_buffer_access.cc b/src/s_tir/schedule/primitive/annotate_buffer_access.cc index 82a3a0de1cfe..823ef42433c0 100644 --- a/src/s_tir/schedule/primitive/annotate_buffer_access.cc +++ b/src/s_tir/schedule/primitive/annotate_buffer_access.cc @@ -96,13 +96,13 @@ void AnnotateBufferAccess(ScheduleState self, const StmtSRef& block_sref, int bu for (const IterVar& iter_var : block->iter_vars) { block_iter_vars.push_back(iter_var->var); } - ffi::Array new_indices = index_map->MapIndices(block_iter_vars, &analyzer); + ffi::Array new_indices = index_map->MapIndices(block_iter_vars, analyzer); TVM_FFI_ICHECK_EQ(new_indices.size() % 2, 0) << "The size of new_indices should be even."; ffi::Array new_ranges; for (size_t i = 0; i < new_indices.size(); i += 2) { // (begin, end) represents a region new_ranges.push_back(Range::FromMinExtent( - new_indices[i], analyzer.Simplify(new_indices[i + 1] - new_indices[i]))); + new_indices[i], analyzer->Simplify(new_indices[i + 1] - new_indices[i]))); } BufferRegion new_region(buffer, new_ranges); diff --git a/src/s_tir/schedule/primitive/blockize_tensorize.cc b/src/s_tir/schedule/primitive/blockize_tensorize.cc index 4848c582c234..cf8108e870c4 100644 --- a/src/s_tir/schedule/primitive/blockize_tensorize.cc +++ b/src/s_tir/schedule/primitive/blockize_tensorize.cc @@ -164,7 +164,7 @@ ffi::Array> SubspaceDivide(const SBlockRealize& real const StmtSRef& block_sref, // const StmtSRef& loop_sref, // std::vector* loops, - arith::Analyzer* analyzer, + arith::AnalyzerObj* analyzer, bool preserve_unit_iters, bool loop_sref_as_outer = false) { ffi::Array inner_vars; @@ -188,7 +188,7 @@ ffi::Array> SubspaceDivide(const SBlockRealize& real } ffi::Array> result = arith::SubspaceDivide(realize->iter_values, loop_var_domain, inner_vars, realize->predicate, - arith::IterMapLevel::Surjective, analyzer, + arith::IterMapLevel::Surjective, ffi::GetRef(analyzer), /*simplify_trivial_iterators=*/!preserve_unit_iters); if (!result.empty()) { return result; @@ -240,9 +240,9 @@ ffi::Map DeriveBlockBinding( IterVar outer_iter; if (reuse_outer) { outer_iter = outer_iter_vars->operator[](i); - TVM_FFI_ICHECK(ana.CanProveEqual(outer_iter->dom->extent, outer_mark->extent)); + TVM_FFI_ICHECK(ana->CanProveEqual(outer_iter->dom->extent, outer_mark->extent)); TVM_FFI_ICHECK( - ana.CanProveEqual(outer_bindings->operator[](i), NormalizeIterMapToExpr(outer_binding))); + ana->CanProveEqual(outer_bindings->operator[](i), NormalizeIterMapToExpr(outer_binding))); } else { outer_iter = IterVar(/*dom=*/RangeFromExtent(outer_mark->extent), /*var=*/iter_var->var.copy_with_suffix("_o"), @@ -382,10 +382,10 @@ Stmt GenerateOuterInit(const Stmt& block_init, const SBlockRealize& inner_realiz * \return The substituted stmt. */ Stmt Substitute(const Stmt& stmt, const ffi::Map& sub, - ffi::Map* block_sref_reuse, arith::Analyzer* analyzer) { + ffi::Map* block_sref_reuse, arith::AnalyzerObj* analyzer) { struct Replacer : public StmtExprMutator { explicit Replacer(const ffi::Map& sub, - ffi::Map* block_sref_reuse, arith::Analyzer* analyzer) + ffi::Map* block_sref_reuse, arith::AnalyzerObj* analyzer) : sub_(sub), block_sref_reuse_(block_sref_reuse), analyzer_(analyzer) {} PrimExpr VisitExpr(const PrimExpr& op) final { @@ -414,7 +414,7 @@ Stmt Substitute(const Stmt& stmt, const ffi::Map& sub, const ffi::Map& sub_; ffi::Map* block_sref_reuse_; - arith::Analyzer* analyzer_; + arith::AnalyzerObj* analyzer_; }; return Replacer(sub, block_sref_reuse, analyzer)(stmt); } @@ -492,7 +492,7 @@ Stmt MakeLoopNest(Stmt stmt, const std::vector& loops) { } SBlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref, - ffi::Map* block_sref_reuse, arith::Analyzer* analyzer, + ffi::Map* block_sref_reuse, arith::AnalyzerObj* analyzer, bool preserve_unit_iters) { TVM_SREF_TO_FOR(loop_sref); // Step 1: Check and get the only block under `loop`. @@ -565,7 +565,7 @@ StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool preserve_u arith::Analyzer analyzer; ffi::Map block_sref_reuse; SBlockRealize blockized = - BlockizeImpl(self, loop_sref, &block_sref_reuse, &analyzer, preserve_unit_iters); + BlockizeImpl(self, loop_sref, &block_sref_reuse, analyzer.get(), preserve_unit_iters); self->Replace(loop_sref, blockized, block_sref_reuse); StmtSRef result = self->stmt2ref.at(blockized->block.get()); StmtSRef scope_root = GetScopeRoot(self, result, /*require_stage_pipeline=*/false); @@ -593,7 +593,7 @@ SBlockRealize BlockizeBlocks(const ScheduleState& self, const ffi::Array loops; ffi::Array> division = SubspaceDivide( - block_realize, block_sref, lca, &loops, &analyzer, preserve_unit_iters, true); + block_realize, block_sref, lca, &loops, analyzer.get(), preserve_unit_iters, true); if (division.empty()) { throw SubspaceNotDivisibleError(self->mod, ffi::GetRef(loops.back()), block); } @@ -617,10 +617,10 @@ SBlockRealize BlockizeBlocks(const ScheduleState& self, const ffi::Arraydom, loop_var_subst); inner_iter_dom.Set(iter->var, arith::IntSet::FromRange(dom)); - analyzer.Bind(iter->var, dom); + analyzer->Bind(iter->var, dom); } SBlock block_subst = - Downcast(Substitute(block, block_var_subst, block_sref_reuse, &analyzer)); + Downcast(Substitute(block, block_var_subst, block_sref_reuse, analyzer.get())); auto reads = EvalSetRegions(block_subst->reads, inner_iter_dom); auto writes = EvalSetRegions(block_subst->writes, inner_iter_dom); read_regions.insert(read_regions.end(), reads.begin(), reads.end()); @@ -760,7 +760,8 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int } else if (sref->stmt->IsInstance()) { arith::Analyzer analyzer; ffi::Map block_sref_reuse; - block_realize = BlockizeImpl(self, sref, &block_sref_reuse, &analyzer, preserve_unit_iters); + block_realize = + BlockizeImpl(self, sref, &block_sref_reuse, analyzer.get(), preserve_unit_iters); } else { TVM_FFI_THROW(TypeError) << "Tensorize only support For or SBlock, but gets: " << ffi::GetRef(sref->stmt); @@ -768,7 +769,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int } arith::Analyzer analyzer; - PrimFunc intrin_desc = StmtSimplify(intrin->desc, &analyzer); + PrimFunc intrin_desc = StmtSimplify(intrin->desc, analyzer.get()); PrimFunc intrin_impl = DeepCopy(intrin->impl); int index_dtype_bits = -1; diff --git a/src/s_tir/schedule/primitive/cache_index.cc b/src/s_tir/schedule/primitive/cache_index.cc index 3cd33aea0c51..4a6d4495f858 100644 --- a/src/s_tir/schedule/primitive/cache_index.cc +++ b/src/s_tir/schedule/primitive/cache_index.cc @@ -60,11 +60,11 @@ struct IndexInfo { */ DataType DetermineDatatype(const arith::IntSet& range) { arith::Analyzer ana; - if (ana.CanProve(range.min() >= INT32_MIN && range.max() <= INT32_MAX)) { + if (ana->CanProve(range.min() >= INT32_MIN && range.max() <= INT32_MAX)) { return DataType::Int(32); } else { - TVM_FFI_ICHECK(ana.CanProve(range.min() >= make_const(DataType::Int(64), INT64_MIN) && - range.max() <= make_const(DataType::Int(64), INT64_MAX))); + TVM_FFI_ICHECK(ana->CanProve(range.min() >= make_const(DataType::Int(64), INT64_MIN) && + range.max() <= make_const(DataType::Int(64), INT64_MAX))); return DataType::Int(64); } } @@ -483,7 +483,7 @@ ffi::Array CacheIndex(ScheduleState self, const StmtSRef& block_sref, StmtSRef parent_sref = ffi::GetRef(result_block_sref->parent); affine_binding = IsAffineBinding(/*realize=*/GetSBlockRealize(self, result_block_sref), /*loop_var_ranges=*/LoopDomainOfSRefTreePath(parent_sref), - /*analyzer=*/&analyzer); + /*analyzer=*/analyzer.get()); } block_info.affine_binding = affine_binding; diff --git a/src/s_tir/schedule/primitive/cache_index_helpers.cc b/src/s_tir/schedule/primitive/cache_index_helpers.cc index 907c67ccb0c5..af721f07e3b2 100644 --- a/src/s_tir/schedule/primitive/cache_index_helpers.cc +++ b/src/s_tir/schedule/primitive/cache_index_helpers.cc @@ -393,7 +393,7 @@ bool EqualTerms(const PrimExpr& a, const PrimExpr& b) { PrimExpr NormalizeTerm(const PrimExpr& expr, bool do_normalization) { if (do_normalization) { arith::Analyzer analyzer; - return analyzer.Simplify(expr); + return analyzer->Simplify(expr); } else { return expr; } diff --git a/src/s_tir/schedule/primitive/cache_read_write.cc b/src/s_tir/schedule/primitive/cache_read_write.cc index 2cb9b5ac9484..ca854279c4ba 100644 --- a/src/s_tir/schedule/primitive/cache_read_write.cc +++ b/src/s_tir/schedule/primitive/cache_read_write.cc @@ -450,7 +450,7 @@ bool CalculateAffineFlag(const ScheduleState& self, const StmtSRef& block_sref) StmtSRef parent_sref = ffi::GetRef(block_sref->parent); return IsAffineBinding(/*realize=*/GetSBlockRealize(self, block_sref), /*loop_var_ranges=*/LoopDomainOfSRefTreePath(parent_sref), - /*analyzer=*/&analyzer); + /*analyzer=*/analyzer.get()); } /*! @@ -632,7 +632,7 @@ BufferRegion RelaxBufferRegion(ScheduleState self, const BufferRegion& buffer_re /*predicate=*/Substitute(realize->predicate && extra_predicate, binding), /*dom_low_inclusive=*/dom_low_inclusive, /*dom_high_exclusive=*/dom_high_exclusive, - /*analyzer=*/&analyzer); + /*analyzer=*/analyzer.get()); TVM_FFI_ICHECK_EQ(buffer_region->region.size(), int_sets.size()); Region region; @@ -905,7 +905,7 @@ class CacheReadRewriter : public StmtExprMutator { TVM_FFI_ICHECK_EQ(region.size(), offset.size()); std::vector ret; for (size_t i = 0; i < region.size(); ++i) { - ret.push_back(Range::FromMinExtent(ana_.Simplify(region[i]->min - offset[i]->min), + ret.push_back(Range::FromMinExtent(ana_->Simplify(region[i]->min - offset[i]->min), region[i]->extent)); } return ret; @@ -1019,7 +1019,7 @@ class CacheReadRewriter : public StmtExprMutator { ffi::Array RewriteIndices(const ffi::Array& indices) { std::vector ret; for (size_t i = 0; i < indices.size(); ++i) { - ret.push_back(ana_.Simplify(indices[i] - info_->cache_region->region[i]->min)); + ret.push_back(ana_->Simplify(indices[i] - info_->cache_region->region[i]->min)); } return ret; } @@ -1162,7 +1162,7 @@ class CacheWriteRewriter : public StmtExprMutator { TVM_FFI_ICHECK_EQ(region.size(), offset.size()); std::vector ret; for (size_t i = 0; i < region.size(); ++i) { - ret.push_back(Range::FromMinExtent(ana_.Simplify(region[i]->min - offset[i]->min), + ret.push_back(Range::FromMinExtent(ana_->Simplify(region[i]->min - offset[i]->min), region[i]->extent)); } return ret; @@ -1289,7 +1289,7 @@ class CacheWriteRewriter : public StmtExprMutator { ffi::Array RewriteIndices(const ffi::Array& indices) { std::vector ret; for (size_t i = 0; i < indices.size(); ++i) { - ret.push_back(ana_.Simplify(indices[i] - info_->cache_region->region[i]->min)); + ret.push_back(ana_->Simplify(indices[i] - info_->cache_region->region[i]->min)); } return ret; } @@ -1990,8 +1990,8 @@ void CollectReindexCacheStageInfoAndCreateBuffer( block_iter_vars.push_back(iter_var); block_shape.push_back(iter_var->dom->extent); } - ffi::Array new_indices = index_map->MapIndices(block_iter_vars, &analyzer); - ffi::Array new_shape = index_map->MapShape(block_shape, &analyzer); + ffi::Array new_indices = index_map->MapIndices(block_iter_vars, analyzer); + ffi::Array new_shape = index_map->MapShape(block_shape, analyzer); info->indices = new_indices; // Step 5. Update CacheTouchedInfo @@ -2323,10 +2323,10 @@ StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_inde ffi::Array original_indices = ReIndexCollector::Collect(self->mod, buffer, block); // Simplify the indices if possible for (const IterVar& iter : block->iter_vars) { - analyzer.Bind(iter->var, iter->dom); + analyzer->Bind(iter->var, iter->dom); } original_indices.MutateByApply( - [&analyzer](const PrimExpr& expr) { return SimplifyNonTrivialExpr(expr, &analyzer); }); + [&analyzer](const PrimExpr& expr) { return SimplifyNonTrivialExpr(expr, analyzer.get()); }); // Collect block iters appearing in the original_indices std::unordered_set covered; diff --git a/src/s_tir/schedule/primitive/compute_at.cc b/src/s_tir/schedule/primitive/compute_at.cc index 79dd56241cf1..2d0a1b960b6a 100644 --- a/src/s_tir/schedule/primitive/compute_at.cc +++ b/src/s_tir/schedule/primitive/compute_at.cc @@ -82,7 +82,7 @@ class NotInSameScopeError : public ScheduleError { public: static void CheckAndBindLoopDomain(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref, const StmtSRef& scope_root_sref, - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { for (const StmtSRefNode* p = loop_sref.get();; p = p->parent) { if (const ForNode* loop = p->StmtAs()) { analyzer->Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); @@ -201,7 +201,7 @@ struct BlockVarDomainInfo { } /*! \brief Simplify domain info */ - void Simplify(arith::Analyzer* analyzer) { + void Simplify(arith::AnalyzerObj* analyzer) { auto to_simplified = [analyzer](const arith::IntSet& set) { PrimExpr min = set.HasLowerBound() ? analyzer->Simplify(set.min()) : set.min(); PrimExpr max = set.HasUpperBound() ? analyzer->Simplify(set.max()) : set.max(); @@ -255,7 +255,7 @@ class ScopeReconstructor : private StmtMutator { * \param preserve_unit_loops Whether to generate unit loops where the loop extent is 1 */ void MakeNewLoop(int insert_position, std::vector iter_doms, - arith::Analyzer* analyzer, bool preserve_unit_loops) { + arith::AnalyzerObj* analyzer, bool preserve_unit_loops) { int n_iters = iter_doms.size(); ffi::Array loop_vars; ffi::Array loop_extents; @@ -409,7 +409,7 @@ void RelaxBufferRegions(const ffi::Map& binding, std::pair SolveBlockVarDomain(const arith::IntSet& provided, const arith::IntSet& required, PrimExpr dim_max, - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { PrimExpr provided_min = analyzer->Simplify(provided.min()); PrimExpr provided_max = analyzer->Simplify(provided.max()); PrimExpr required_min = analyzer->Simplify(required.min()); @@ -484,15 +484,17 @@ std::pair SolveBlockVarDomain(const arith::IntSet& prov */ void UpdateBlockVarDomainDimwise( const BufferNode* buffer, const NDIntSet& provided_region, const NDIntSet& required_region, - arith::Analyzer* analyzer, std::unordered_map* iter_doms) { + arith::AnalyzerObj* analyzer, + std::unordered_map* iter_doms) { size_t ndim = buffer->shape.size(); for (size_t i = 0; i < ndim; ++i) { arith::IntSet provided = provided_region[i]; arith::IntSet required = required_region[i]; PrimExpr dim_max = max(buffer->shape[i] - 1, 0); + arith::Analyzer analyzer_ref = ffi::GetRef(analyzer); - if (provided.CanProveSinglePoint(analyzer) && is_const_int(provided.min())) { - TVM_FFI_ICHECK(required.CanProveSinglePoint(analyzer) && + if (provided.CanProveSinglePoint(analyzer_ref) && is_const_int(provided.min())) { + TVM_FFI_ICHECK(required.CanProveSinglePoint(analyzer_ref) && analyzer->CanProveEqual(provided.min(), required.min())); continue; } @@ -511,7 +513,7 @@ void UpdateBlockVarDomainDimwise( /*! \brief Helper function to implement intset version of `InverseAffineIterMap`. */ ffi::Map InverseAffineIterMap(const ffi::Array& iter_map, const NDIntSet& outputs, - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { ffi::Array min_point, max_point; min_point.reserve(outputs.size()); max_point.reserve(outputs.size()); @@ -549,11 +551,12 @@ ffi::Map InverseAffineIterMap(const ffi::Array& iter_vars, const NDIntSet& provided_region, const NDIntSet& required_region, - arith::Analyzer* analyzer, + arith::AnalyzerObj* analyzer, std::unordered_map* iter_doms) { // we only support single point provided region now, which could cover most cases + arith::Analyzer analyzer_ref = ffi::GetRef(analyzer); for (const auto& intset : provided_region) { - if (!intset.CanProveSinglePoint(analyzer)) return false; + if (!intset.CanProveSinglePoint(analyzer_ref)) return false; } // calculate forward mapping (block vars -> provided region point) ffi::Map dom_map; @@ -567,7 +570,7 @@ bool UpdateBlockVarDomainAffine(const BufferNode* buffer, const ffi::Arrayindices.empty()) { return false; } @@ -602,7 +605,7 @@ std::vector CalculateBlockVarDomain( const ffi::Array& iter_vars, std::unordered_map> provided_regions, std::unordered_map> required_regions, - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { int n_iters = iter_vars.size(); // Step 1. Construct the mapping from block var to their iteration domain (initialized to empty) std::unordered_map iter_doms; @@ -693,7 +696,7 @@ void CalculateProvidedRequiredRegions( template void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref, bool preserve_unit_loops, - arith::Analyzer* analyzer, bool check_only = false, + arith::AnalyzerObj* analyzer, bool check_only = false, int index = -1) { const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); @@ -768,15 +771,15 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref, bool preserve_unit_loops, int index) { arith::Analyzer analyzer; - ComputeAtOrReverseComputeAtImpl(self, block_sref, loop_sref, preserve_unit_loops, &analyzer, - false, index); + ComputeAtOrReverseComputeAtImpl(self, block_sref, loop_sref, preserve_unit_loops, + analyzer.get(), false, index); } void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref, bool preserve_unit_loops, int index) { arith::Analyzer analyzer; ComputeAtOrReverseComputeAtImpl(self, block_sref, loop_sref, preserve_unit_loops, - &analyzer, false, index); + analyzer.get(), false, index); } bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref, @@ -784,7 +787,7 @@ bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const S arith::Analyzer analyzer; try { ComputeAtOrReverseComputeAtImpl(self, block_sref, loop_sref, preserve_unit_loops, - &analyzer, true); + analyzer.get(), true); } catch (const tvm::ffi::Error& e) { return false; } @@ -796,7 +799,7 @@ bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref, arith::Analyzer analyzer; try { ComputeAtOrReverseComputeAtImpl(self, block_sref, loop_sref, preserve_unit_loops, - &analyzer, true); + analyzer.get(), true); } catch (const tvm::ffi::Error& e) { return false; } diff --git a/src/s_tir/schedule/primitive/compute_inline.cc b/src/s_tir/schedule/primitive/compute_inline.cc index 20043b720a39..08a990b62da9 100644 --- a/src/s_tir/schedule/primitive/compute_inline.cc +++ b/src/s_tir/schedule/primitive/compute_inline.cc @@ -480,8 +480,8 @@ class ComputeInliner : public BaseInliner { const IterVar& iter = producer_block->iter_vars[i]; const PrimExpr& e = inlined_store_->indices[i]; if (e.same_as(iter->var) || - (analyzer_.CanProveEqual(e, 0) && analyzer_.CanProveEqual(iter->dom->min, 0) && - analyzer_.CanProveEqual(iter->dom->extent, 1))) { + (analyzer_->CanProveEqual(e, 0) && analyzer_->CanProveEqual(iter->dom->min, 0) && + analyzer_->CanProveEqual(iter->dom->extent, 1))) { idx_vars.push_back(iter->var); } else { break; @@ -505,7 +505,7 @@ class ComputeInliner : public BaseInliner { /*input_iters=*/producer_iter_doms, /*predicate=*/true, /*check_level=*/arith::IterMapLevel::Bijective, - /*analyzer=*/&analyzer_, + /*analyzer=*/analyzer_, /*simplify_trivial_iterators=*/false); if (!res->errors.empty()) { // Failure: indices of BufferStore are not bijective affine @@ -518,7 +518,7 @@ class ComputeInliner : public BaseInliner { auto inverse_iter_map = arith::InverseAffineIterMap( res->indices, ffi::Array(idx_vars_.begin(), idx_vars_.end())); for (const auto& iter : producer_block->iter_vars) { - if (is_const_int(iter->dom->min) && analyzer_.CanProveEqual(iter->dom->extent, 1)) { + if (is_const_int(iter->dom->min) && analyzer_->CanProveEqual(iter->dom->extent, 1)) { // fallback mapping for constant iters inverse_iter_map.Set(iter->var, iter->dom->min); } @@ -671,7 +671,7 @@ class ReverseComputeInliner : public BaseInliner { /*input_iters=*/consumer_iter_doms, /*predicate=*/true, /*check_level=*/arith::IterMapLevel::NoCheck, - /*analyzer=*/&analyzer_, + /*analyzer=*/analyzer_, /*simplify_trivial_iterators=*/false); buffer_load_iter_map_ = res->indices; if (buffer_load_iter_map_.empty()) { @@ -721,12 +721,12 @@ class ReverseComputeInliner : public BaseInliner { const IterVar& iter = producer_block->iter_vars[i]; const PrimExpr& binding = producer_block_realize->iter_values[i]; subst_map.Set(iter->var, binding); - analyzer_.Bind(iter->var, Range::FromMinExtent(iter->dom->min, iter->dom->extent)); + analyzer_->Bind(iter->var, Range::FromMinExtent(iter->dom->min, iter->dom->extent)); } if (producer_block->annotations.count(s_tir::attr::auto_copy) != 0) { auto bind = [&](const ForNode* loop) { - analyzer_.Bind(loop->loop_var, - Range::FromMinExtent(make_zero(loop->extent->dtype), loop->extent)); + analyzer_->Bind(loop->loop_var, + Range::FromMinExtent(make_zero(loop->extent->dtype), loop->extent)); }; const ForNode* producer_inner_loop = producer_block->body.as(); while (producer_inner_loop->body.as()) { @@ -738,15 +738,15 @@ class ReverseComputeInliner : public BaseInliner { // Substitute the consumer block iters with the corresponding iters in the producer blocks PrimExpr predicate = Substituter(this)(consumer_iter_in_bound_); // Simplify the predicate using the producer block iter domains - predicate = analyzer_.Simplify(predicate); + predicate = analyzer_->Simplify(predicate); if (is_one(predicate)) { return producer_block_realize; } if (const auto* if_ = producer_block->body.as()) { if (!if_->else_case.defined()) { - PrimExpr if_predicate = analyzer_.Simplify(if_->condition); + PrimExpr if_predicate = analyzer_->Simplify(if_->condition); if (!ffi::StructuralEqual()(predicate, if_predicate)) { - predicate = analyzer_.Simplify(predicate && if_->condition); + predicate = analyzer_->Simplify(predicate && if_->condition); producer_block.CopyOnWrite()->body = if_->then_case; } } @@ -754,7 +754,7 @@ class ReverseComputeInliner : public BaseInliner { PrimExpr outer_predicate = Substitute(predicate, subst_map); auto n = producer_block_realize.CopyOnWrite(); n->block = producer_block; - n->predicate = analyzer_.Simplify(outer_predicate); + n->predicate = analyzer_->Simplify(outer_predicate); return ffi::GetRef(n); } @@ -790,9 +790,9 @@ class ReverseComputeInliner : public BaseInliner { if (auto it = idx_sub_.find(iter->var.get()); it != idx_sub_.end()) { const PrimExpr& producer_iter = it->second; arith::IntSet producer_iter_range = arith::EvalSet(producer_iter, producer_iter_doms); - if (analyzer_.CanProve(producer_iter_range.min() > iter->dom->min) || - analyzer_.CanProve(producer_iter_range.max() < - iter->dom->min + iter->dom->extent - 1)) { + if (analyzer_->CanProve(producer_iter_range.min() > iter->dom->min) || + analyzer_->CanProve(producer_iter_range.max() < + iter->dom->min + iter->dom->extent - 1)) { return false; } } else { @@ -972,7 +972,7 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block /*realize=*/GetSBlockRealize(self, producer_block_sref), /*loop_var_ranges=*/ LoopDomainOfSRefTreePath(ffi::GetRef(producer_block_sref->parent)), - /*analyzer=*/&analyzer); + /*analyzer=*/analyzer.get()); } bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref) { @@ -1299,7 +1299,7 @@ SBlock ReductionEpilogueFuser::CreateFusedReductionBlock( // Simplify the expression (e.g., 0 + C[vi, vj] -> C[vi, vj]) arith::Analyzer analyzer; - init_epilogue = analyzer.Simplify(init_epilogue); + init_epilogue = analyzer->Simplify(init_epilogue); BufferStore new_init_store = BufferStore(epilogue_output_buffer_, init_epilogue, Substitute(epilogue_output_indices_, var_map)); diff --git a/src/s_tir/schedule/primitive/decompose_padding.cc b/src/s_tir/schedule/primitive/decompose_padding.cc index ee2045b7eef6..c67b2afbb4ba 100644 --- a/src/s_tir/schedule/primitive/decompose_padding.cc +++ b/src/s_tir/schedule/primitive/decompose_padding.cc @@ -71,7 +71,7 @@ class PaddingInfoAnalyzer { public: static PaddingSBlockInfo CheckAndGetPaddingInfo(IRModule mod, const SBlockRealizeNode* realize, const ffi::Map& dom_map, - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { PaddingInfoAnalyzer padding_analyzer(analyzer); if (!padding_analyzer.MatchPadding(realize, dom_map)) { throw PaddingPatternMatchError(mod, realize->block, padding_analyzer.error_msg_); @@ -80,7 +80,7 @@ class PaddingInfoAnalyzer { } private: - explicit PaddingInfoAnalyzer(arith::Analyzer* analyzer) : analyzer_(analyzer) {} + explicit PaddingInfoAnalyzer(arith::AnalyzerObj* analyzer) : analyzer_(analyzer) {} /*! \brief Detect padding pattern and update result. */ bool MatchPadding(const SBlockRealizeNode* realize, const ffi::Map& dom_map) { @@ -164,8 +164,9 @@ class PaddingInfoAnalyzer { const PrimExpr& in_bound_predicate) { ffi::Array region; + arith::Analyzer analyzer_ref = ffi::GetRef(analyzer_); auto res = arith::DetectIterMap(iter_values, dom_map, in_bound_predicate, - arith::IterMapLevel::Surjective, analyzer_); + arith::IterMapLevel::Surjective, analyzer_ref); if (res->indices.empty()) { SetError("Block iters are not independent wrt padding condition"); return {}; @@ -192,7 +193,7 @@ class PaddingInfoAnalyzer { /*! \brief current error message. */ std::string error_msg_; /*! \brief arithmetic analyzer. */ - arith::Analyzer* analyzer_; + arith::AnalyzerObj* analyzer_; }; /*! \brief Create block to fill constant pad values into full region */ @@ -200,7 +201,7 @@ static std::pair CreateConstBlock(const SBlockRealizeNode* const PaddingSBlockInfo& info, const ffi::Array& loops, const Stmt& highest_pos_inclusive, - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { const SBlock& block = realize->block; ffi::Array new_iter_vars; ffi::Map repl_dict; @@ -269,7 +270,7 @@ static std::pair CreateInBoundBlock(const SBlockRealizeNode const ffi::Array& loops, const Stmt& highest_pos_inclusive, - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { const SBlock& block = realize->block; ffi::Array new_iter_vars; ffi::Map repl_dict; @@ -435,7 +436,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref, For cur_loop = ffi::GetRef((*it)->StmtAs()); Range range = Range::FromMinExtent(cur_loop->min, cur_loop->extent); dom_map.Set(cur_loop->loop_var, range); - analyzer.Bind(cur_loop->loop_var, range); + analyzer->Bind(cur_loop->loop_var, range); loops.push_back(cur_loop); if (cur_loop.same_as(const_filling_pos)) { @@ -462,7 +463,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref, // Check 3. match padding pattern and return padding operation info. PaddingSBlockInfo info = - PaddingInfoAnalyzer::CheckAndGetPaddingInfo(self->mod, realize, dom_map, &analyzer); + PaddingInfoAnalyzer::CheckAndGetPaddingInfo(self->mod, realize, dom_map, analyzer.get()); // IR Manipulation // Step 1. Create const pad value filling part and in-bound value filling part. @@ -470,9 +471,9 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref, replace_desc.const_filling_pos = const_filling_pos; replace_desc.in_bound_filling_pos = in_bound_filling_pos; std::tie(replace_desc.const_filling_loop, replace_desc.const_filling_block) = - CreateConstBlock(realize, info, loops, const_filling_pos, &analyzer); + CreateConstBlock(realize, info, loops, const_filling_pos, analyzer.get()); std::tie(replace_desc.in_bound_filling_loop, replace_desc.in_bound_filling_block) = - CreateInBoundBlock(realize, info, loops, in_bound_filling_pos, &analyzer); + CreateInBoundBlock(realize, info, loops, in_bound_filling_pos, analyzer.get()); // Step 2. Execute IR replacement. SBlock old_scope_root_block = ffi::GetRef(scope_root_sref->StmtAs()); diff --git a/src/s_tir/schedule/primitive/layout_transformation.cc b/src/s_tir/schedule/primitive/layout_transformation.cc index 9878828e3eb9..1a44cf1ff4e7 100644 --- a/src/s_tir/schedule/primitive/layout_transformation.cc +++ b/src/s_tir/schedule/primitive/layout_transformation.cc @@ -98,7 +98,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { static TransformPlan Plan(SBlock block, Buffer old_buffer, Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, - ffi::Optional pad_value, arith::Analyzer* analyzer) { + ffi::Optional pad_value, arith::AnalyzerObj* analyzer) { TVM_FFI_ICHECK(!pad_value.defined() || pad_value.value()->final_indices.size() == 1) << "Internal error: Should be caught by ScheduleError checks prior to this point"; TransformLayoutPlanner visitor(old_buffer); @@ -225,7 +225,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { public: BufferStoreReplacer(const WriteInfo& info, const Buffer& new_buffer, PrimExpr padding_predicate, const IndexMap& inverse, const ffi::Optional& pad_value, - ffi::Map* new_block_to_old, arith::Analyzer* analyzer) + ffi::Map* new_block_to_old, arith::AnalyzerObj* analyzer) : info(info), new_buffer(new_buffer), new_indices(inverse->initial_indices), @@ -359,7 +359,8 @@ class TransformLayoutPlanner : private StmtExprVisitor { if (can_replace) { ffi::Array new_index_exprs = new_indices.Map([](const auto& var) -> PrimExpr { return var; }); - PrimExpr pad_value_at_index = pad_value.value()->MapIndices(new_index_exprs, analyzer)[0]; + PrimExpr pad_value_at_index = pad_value.value()->MapIndices( + new_index_exprs, ffi::GetRef(analyzer))[0]; store = BufferStore(new_buffer, if_then_else(padding_predicate, pad_value_at_index, op->value), new_index_exprs); @@ -435,14 +436,14 @@ class TransformLayoutPlanner : private StmtExprVisitor { const ffi::Optional& pad_value; ffi::Map& new_block_to_old; bool all_stores_replaced{true}; - arith::Analyzer* analyzer; + arith::AnalyzerObj* analyzer; ffi::Map var_remap; }; TransformPlan Finalize(Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, ffi::Optional pad_value, - arith::Analyzer* analyzer) const { + arith::AnalyzerObj* analyzer) const { if (auto prologue_plan = FinalizeProloguePlan(new_buffer, index_map, inverse, padding_predicate, pad_value, analyzer); prologue_plan.has_value()) { @@ -463,7 +464,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { std::optional FinalizeProloguePlan(Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, ffi::Optional pad_value, - arith::Analyzer* analyzer) const { + arith::AnalyzerObj* analyzer) const { if (write_info_.size() || is_zero(padding_predicate) || !pad_value.defined()) { return std::nullopt; } @@ -485,7 +486,8 @@ class TransformLayoutPlanner : private StmtExprVisitor { } padding_predicate = Substitute(std::move(padding_predicate), loop_indices_to_block_indices); - PrimExpr pad_value_at_index = pad_value.value()->MapIndices(indices, analyzer)[0]; + PrimExpr pad_value_at_index = + pad_value.value()->MapIndices(indices, ffi::GetRef(analyzer))[0]; PrimExpr expr = (!padding_predicate) || (BufferLoad(new_buffer, indices) == pad_value_at_index); Stmt stmt = Evaluate(Call(DataType::Bool(), builtin::assume(), {expr})); @@ -508,7 +510,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { IndexMap inverse, PrimExpr padding_predicate, ffi::Optional pad_value, - arith::Analyzer* analyzer) const { + arith::AnalyzerObj* analyzer) const { if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) { return std::nullopt; } @@ -558,7 +560,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { std::optional FinalizeEpiloguePlan(Buffer new_buffer, IndexMap index_map, IndexMap inverse, PrimExpr padding_predicate, ffi::Optional pad_value, - arith::Analyzer* analyzer) const { + arith::AnalyzerObj* analyzer) const { if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) { return std::nullopt; } @@ -577,7 +579,8 @@ class TransformLayoutPlanner : private StmtExprVisitor { iter_values.push_back(loop_var); } - PrimExpr pad_value_at_index = pad_value.value()->MapIndices(indices, analyzer)[0]; + PrimExpr pad_value_at_index = + pad_value.value()->MapIndices(indices, ffi::GetRef(analyzer))[0]; Stmt stmt = BufferStore(new_buffer, pad_value_at_index, indices); std::stringstream block_name; @@ -759,10 +762,10 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { auto plan = pad_value.defined() ? TransformLayoutPlanner::Plan(scope_stmt, old_buffer, new_buffer, index_map, opt_inverse.value(), padding_predicate, - pad_value, &analyzer) + pad_value, analyzer.get()) : TransformLayoutPlanner::NoPaddingRequired(); - TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, plan, &analyzer); + TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, plan, analyzer.get()); SBlock result = Downcast(rewriter(scope_stmt)); if (auto plan_ptr = std::get_if(&plan)) { auto write_ptr = result.CopyOnWrite(); @@ -779,7 +782,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { TransformLayoutRewriter(const Buffer& old_buffer, const Buffer& new_buffer, const IndexMap& index_map, const TransformLayoutPlanner::TransformPlan& plan, - arith::Analyzer* analyzer) + arith::AnalyzerObj* analyzer) : IRMutatorWithAnalyzer(analyzer), old_buffer_(old_buffer), new_buffer_(new_buffer), @@ -793,7 +796,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { void RewriteBufferAccess(Buffer* buffer, ffi::Array* indices) { *buffer = new_buffer_; - *indices = index_map_->MapIndices(*indices, &index_simplifier_); + *indices = index_map_->MapIndices(*indices, index_simplifier_); *indices = this->IterMapSimplifyWithContext(*indices, true); } @@ -1088,7 +1091,7 @@ class TransformationIntroducesPaddingError : public ScheduleError { ffi::String DetailRenderTemplate() const final { arith::Analyzer analyzer; - auto new_shape = index_map_->MapShape(buffer_->shape, &analyzer); + auto new_shape = index_map_->MapShape(buffer_->shape, analyzer); std::ostringstream os; os << "The transformation " << index_map_ << " applied on buffer " << buffer_->name << " of shape " << buffer_->shape << " would result in shape " << new_shape @@ -1158,7 +1161,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ BufferIndexType buffer_index_type, const IndexMap& index_map_orig, const ffi::Optional& pad_value, bool assume_injective_transform) { arith::Analyzer analyzer; - AddShapeVarBounds(self, block_sref.get(), &analyzer); + AddShapeVarBounds(self, block_sref.get(), analyzer.get()); // Step 1: Input handling and error checking const SBlockNode* block_ptr = TVM_SREF_TO_SBLOCK(block_sref); Buffer old_buffer = @@ -1194,7 +1197,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ for (const auto& dim : old_buffer->shape) { region.push_back(Range::FromMinExtent(make_zero(dim.dtype()), dim)); } - return index_map.NonSurjectiveInverse(region, &analyzer); + return index_map.NonSurjectiveInverse(region, analyzer); }(); } @@ -1205,7 +1208,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ // Step 2: Infer the shape of the new buffer Buffer new_buffer = old_buffer; - new_buffer.CopyOnWrite()->shape = index_map->MapShape(old_buffer->shape, &analyzer); + new_buffer.CopyOnWrite()->shape = index_map->MapShape(old_buffer->shape, analyzer); // Step 3: Rewrite BufferLoad/BufferStore access indices, block read/write regions, and block // alloc_buffers. @@ -1360,13 +1363,13 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, const SBlockNode* block_ptr = TVM_SREF_TO_SBLOCK(block_sref); const SBlock& block = ffi::GetRef(block_ptr); arith::Analyzer analyzer; - AddShapeVarBounds(self, block_sref.get(), &analyzer); + AddShapeVarBounds(self, block_sref.get(), analyzer.get()); // Step 1: Collect outer loops and loop vars ffi::Array loops = GetLoops(block_sref); // outer loops of the block std::unordered_set loop_vars; // loop vars of the outer loops for (const StmtSRef& loop_sref : loops) { - CheckLoopStartsWithZero(self, loop_sref, &analyzer); + CheckLoopStartsWithZero(self, loop_sref, analyzer.get()); loop_vars.emplace(loop_sref->StmtAs()->loop_var.get()); } @@ -1400,9 +1403,8 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, // Step 4: Apply the IndexMap to block iters. IndexMapNotApplicableToBlockIterError::Check(self->mod, block, index_map); - ffi::Array transformed_block_iters = index_map->MapIndices(block_vars, &analyzer); - ffi::Array new_block_iter_range = - index_map->MapShape(block_iter_range_array, &analyzer); + ffi::Array transformed_block_iters = index_map->MapIndices(block_vars, analyzer); + ffi::Array new_block_iter_range = index_map->MapShape(block_iter_range_array, analyzer); // Step 5: Create the new block after transformation. @@ -1440,13 +1442,13 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, } IndexMap inverse_index_map{nullptr}; try { - inverse_index_map = index_map.Inverse(initial_ranges, &analyzer); + inverse_index_map = index_map.Inverse(initial_ranges, analyzer); } catch (...) { throw NotBijectiveAffineIndexMapError(self->mod, index_map); } // old block vars written in terms of new block vars ffi::Array inversed_new_block_vars = - inverse_index_map->MapIndices(new_block_vars, &analyzer); + inverse_index_map->MapIndices(new_block_vars, analyzer); for (int i = 0, n = block_vars.size(); i < n; ++i) { inverse_subst_map.Set(Downcast(block_vars[i]), inversed_new_block_vars[i]); } @@ -1454,7 +1456,7 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, SBlock new_block = Downcast(Substitute(ffi::GetRef(block_ptr), inverse_subst_map)); new_block.CopyOnWrite()->iter_vars = new_block_iters; - new_block = Downcast(BlockBufferAccessSimplifier::Simplify(new_block, &analyzer)); + new_block = Downcast(BlockBufferAccessSimplifier::Simplify(new_block, analyzer.get())); // Step 5.3: Create outer loops for each new block iter. diff --git a/src/s_tir/schedule/primitive/loop_transformation.cc b/src/s_tir/schedule/primitive/loop_transformation.cc index 8011b09d0c29..649f63ab88c9 100644 --- a/src/s_tir/schedule/primitive/loop_transformation.cc +++ b/src/s_tir/schedule/primitive/loop_transformation.cc @@ -125,7 +125,7 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { /*input_iters=*/loop_var2extent_, /*input_pred=*/op->predicate, /*check_level=*/arith::IterMapLevel::Surjective, - /*analyzer=*/&analzyer_, + /*analyzer=*/analzyer_, /*simplify_trivial_iterators=*/!preserve_unit_iters_); if (v.same_as(op->iter_values)) { return ffi::GetRef(op); @@ -407,7 +407,7 @@ ffi::Array Split(ScheduleState self, const StmtSRef& loop_sref, } // Currently, loops not starting with 0 are not supported arith::Analyzer analyzer; - CheckLoopStartsWithZero(self, loop_sref, &analyzer); + CheckLoopStartsWithZero(self, loop_sref, analyzer.get()); // Find the most common dtype DataType dtype; @@ -426,7 +426,7 @@ ffi::Array Split(ScheduleState self, const StmtSRef& loop_sref, const PrimExpr& factor = factors[i]; Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i)).copy_with_dtype(dtype); substitute_value = substitute_value * factor + var; - analyzer.Bind(var, Range::FromMinExtent(make_const(dtype, 0), tvm::cast(dtype, factor))); + analyzer->Bind(var, Range::FromMinExtent(make_const(dtype, 0), tvm::cast(dtype, factor))); new_loop_vars.emplace_back(std::move(var)); } ffi::Map opaque_block_reuse; @@ -442,7 +442,8 @@ ffi::Array Split(ScheduleState self, const StmtSRef& loop_sref, &opaque_block_reuse)(std::move(new_stmt)); // Step 3. Update predicate to guard the loop PrimExpr predicate = substitute_value < loop->extent; - if (!disable_predication && !analyzer.CanProve(predicate, arith::ProofStrength::kSymbolicBound)) { + if (!disable_predication && + !analyzer->CanProve(predicate, arith::ProofStrength::kSymbolicBound)) { new_stmt = BlockPredicateAppender(/*predicate=*/predicate)(std::move(new_stmt)); } // Step 4. Generate nested loops to replace the original loop and simplify the binding @@ -672,7 +673,7 @@ ffi::Array LoopPartition(ScheduleState self, const StmtSRef& loop_sref // Iterate over each pair of factors and create partition for (int i = 0; i < n; i++) { - extent_value = analyzer.Simplify(factors[i]); + extent_value = analyzer->Simplify(factors[i]); Var new_loop_var = loop->loop_var.copy_with_suffix(std::to_string(i)).copy_with_dtype(dtype); Stmt loop_body = tirx::Substitute(loop->body, {{loop->loop_var, new_loop_var}}); @@ -826,7 +827,7 @@ StmtSRef Merge(ScheduleState self, const ffi::Array& loop_srefs) { if (!loop->annotations.empty() || loop->thread_binding.defined()) { throw HasAnnotationOrThreadBindingError(self->mod, ffi::GetRef(loop)); } - CheckLoopStartsWithZero(self, ffi::GetRef(p), &analyzer); + CheckLoopStartsWithZero(self, ffi::GetRef(p), analyzer.get()); nest_loop_i_loops.push_back(ffi::GetRef(loop)); nest_loop_i_extents.push_back(loop->extent); } @@ -853,7 +854,7 @@ StmtSRef Merge(ScheduleState self, const ffi::Array& loop_srefs) { throw; } else { for (size_t j = 0; j < nest_loop_i_extents.size(); j++) { - if (!analyzer.CanProveEqual(nest_loop_i_extents[j], nest_loop_extents[j])) { + if (!analyzer->CanProveEqual(nest_loop_i_extents[j], nest_loop_extents[j])) { TVM_FFI_THROW(ScheduleError) << "Merge loop's `extent` must be same, but not." << " extent=[" << j << "," << nest_loop_extents[j] << "," << nest_loop_i_extents[j] << "]"; @@ -901,7 +902,7 @@ StmtSRef Fuse(ScheduleState self, const ffi::Array& loop_srefs, } outer_loop_sref = sref; outer_loop = loop; - CheckLoopStartsWithZero(self, sref, &analyzer); + CheckLoopStartsWithZero(self, sref, analyzer.get()); const VarNode* used_var = nullptr; auto f_contain = [&outer_loop_vars, &used_var](const VarNode* var) { if (outer_loop_vars.count(var)) { @@ -932,7 +933,7 @@ StmtSRef Fuse(ScheduleState self, const ffi::Array& loop_srefs, substitute_value.resize(loops.size()); PrimExpr lower = 1; for (int i = static_cast(loops.size()) - 1; i > 0; i--) { - PrimExpr next_lower = analyzer.canonical_simplify(loops[i]->extent * lower); + PrimExpr next_lower = analyzer->canonical_simplify(loops[i]->extent * lower); substitute_value.Set( i, is_one(loops[i]->extent) ? 0 : floordiv(floormod(fused_var, next_lower), lower)); lower = next_lower; @@ -955,7 +956,7 @@ StmtSRef Fuse(ScheduleState self, const ffi::Array& loop_srefs, for (int i = 0; i < n; i++) { fused_extent *= loops[i]->extent; } - fused_extent = analyzer.Simplify(fused_extent); + fused_extent = analyzer->Simplify(fused_extent); new_stmt = For(fused_var, 0, fused_extent, ForKind::kSerial, new_stmt); new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings( std::move(new_stmt), GetLoops(loop_srefs[0]), opaque_block_reuse.CopyOnWrite(), diff --git a/src/s_tir/schedule/primitive/pad_einsum.cc b/src/s_tir/schedule/primitive/pad_einsum.cc index e805ff1e7df3..33fd5390e81f 100644 --- a/src/s_tir/schedule/primitive/pad_einsum.cc +++ b/src/s_tir/schedule/primitive/pad_einsum.cc @@ -159,7 +159,7 @@ struct BufferPadding { return result; } - Stmt MakeCopyBlock(bool is_read, ffi::Array* blocks, arith::Analyzer* analyzer) { + Stmt MakeCopyBlock(bool is_read, ffi::Array* blocks, arith::AnalyzerObj* analyzer) { ffi::Array loop_vars; ffi::Array loop_doms; ffi::Array iter_vars; @@ -390,8 +390,8 @@ void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const ffi::Array< const IterVar& iter = block->iter_vars[i]; PrimExpr dom = iter->dom->extent; PrimExpr pad_imm = IntImm(dom->dtype, padding[i]); - PrimExpr new_dom = analyzer.Simplify(ceildiv(dom, pad_imm) * pad_imm); - if (!analyzer.CanProveEqual(new_dom, dom)) { + PrimExpr new_dom = analyzer->Simplify(ceildiv(dom, pad_imm) * pad_imm); + if (!analyzer->CanProveEqual(new_dom, dom)) { replacer.iter2padded_extents.Set(iter->var, new_dom); if (const auto* loop_var = realize->iter_values[i].as()) { replacer.iter2padded_extents.Set(ffi::GetRef(loop_var), new_dom); @@ -441,7 +441,7 @@ void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const ffi::Array< BufferPadding bp = BufferPadding::FromBufferRegion(buffer_region, replacer.iter2padded_extents); replacer.buffer_map_.Set(bp.buffer, bp.padded_buffer); - read_blocks.push_back(bp.MakeCopyBlock(true, &new_copy_blocks, &analyzer)); + read_blocks.push_back(bp.MakeCopyBlock(true, &new_copy_blocks, analyzer.get())); alloc_buffers.push_back(bp.padded_buffer); } } @@ -450,7 +450,7 @@ void PadEinsum(ScheduleState self, const StmtSRef& block_sref, const ffi::Array< BufferPadding bp = BufferPadding::FromBufferRegion(buffer_region, replacer.iter2padded_extents); replacer.buffer_map_.Set(bp.buffer, bp.padded_buffer); - write_blocks.push_back(bp.MakeCopyBlock(false, &new_copy_blocks, &analyzer)); + write_blocks.push_back(bp.MakeCopyBlock(false, &new_copy_blocks, analyzer.get())); alloc_buffers.push_back(bp.padded_buffer); } } diff --git a/src/s_tir/schedule/primitive/read_write_at.cc b/src/s_tir/schedule/primitive/read_write_at.cc index 7a9e00cbf371..9f0554a53185 100644 --- a/src/s_tir/schedule/primitive/read_write_at.cc +++ b/src/s_tir/schedule/primitive/read_write_at.cc @@ -332,7 +332,7 @@ struct ReadWriteAtImpl { dst_(dst), annotations_(annotations), block_sref_reuse_(), - analyzer_(std::make_unique()) { + analyzer_(arith::Analyzer()) { loop_ = TVM_SREF_TO_FOR(loop_sref); } @@ -343,7 +343,7 @@ struct ReadWriteAtImpl { const Buffer& dst_; ffi::Map annotations_; ffi::Map block_sref_reuse_; - std::unique_ptr analyzer_; + arith::Analyzer analyzer_; }; StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, diff --git a/src/s_tir/schedule/primitive/rolling_buffer.cc b/src/s_tir/schedule/primitive/rolling_buffer.cc index 402cb8aef106..d8e39b95ff85 100644 --- a/src/s_tir/schedule/primitive/rolling_buffer.cc +++ b/src/s_tir/schedule/primitive/rolling_buffer.cc @@ -351,8 +351,8 @@ class RollingBufferRewriter : public StmtExprMutator { std::make_pair(var, arith::IntSet::Interval(0, 0))}; auto iter_value = realize->iter_values[i]; arith::Analyzer analyzer; - auto term_2 = analyzer.int_set(iter_value, dmap).min(); - condition = analyzer.Simplify( + auto term_2 = analyzer->int_set(iter_value, dmap).min(); + condition = analyzer->Simplify( And(condition, Or(LT(var, 1), GE(term_2, info_->axis_overlaps[i])))); } } diff --git a/src/s_tir/schedule/state.cc b/src/s_tir/schedule/state.cc index 6ddc3358106b..865a181a8752 100644 --- a/src/s_tir/schedule/state.cc +++ b/src/s_tir/schedule/state.cc @@ -45,15 +45,16 @@ ffi::Array AnalyzeRegionUpperBound(const BufferRegion& region, const PrimExpr& predicate, // const StmtSRef& dom_low_inclusive, // const StmtSRef& dom_high_exclusive, // - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { ffi::Map var_dom = LoopDomainOfSRefTreePath( /*low_inclusive=*/dom_low_inclusive, /*high_exclusive=*/dom_high_exclusive, /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope())); + arith::Analyzer analyzer_ref = ffi::GetRef(analyzer); return EstimateRegionUpperBound( /*region=*/region->region, /*var_dom=*/var_dom, - /*predicate=*/predicate, /*analyzer=*/analyzer); + /*predicate=*/predicate, /*analyzer=*/analyzer_ref); } /*! @@ -70,15 +71,16 @@ ffi::Array AnalyzeRegionLowerBound(const BufferRegion& region, const PrimExpr& predicate, // const StmtSRef& dom_low_inclusive, // const StmtSRef& dom_high_exclusive, // - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { ffi::Map var_dom = LoopDomainOfSRefTreePath( /*low_inclusive=*/dom_low_inclusive, /*high_exclusive=*/dom_high_exclusive, /*extra_relax_scope=*/runtime::StorageScope::Create(region->buffer.scope())); + arith::Analyzer analyzer_ref = ffi::GetRef(analyzer); if (ffi::Optional> result = EstimateRegionLowerBound( /*region=*/region->region, /*var_dom=*/var_dom, - /*predicate=*/predicate, /*analyzer=*/analyzer)) { + /*predicate=*/predicate, /*analyzer=*/analyzer_ref)) { return result.value(); } return ffi::Array(region->buffer->shape.size(), arith::IntSet::Nothing()); @@ -95,7 +97,7 @@ ffi::Array AnalyzeRegionLowerBound(const BufferRegion& region, bool ProducerCoversConsumer(const ffi::Array& buffer_shape, const ffi::Array& produced_region, const ffi::Array& consumed_region, - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { TVM_FFI_ICHECK_EQ(buffer_shape.size(), consumed_region.size()); TVM_FFI_ICHECK_EQ(produced_region.size(), consumed_region.size()); int ndim = produced_region.size(); @@ -191,7 +193,7 @@ class SBlockInfoCollector : private StmtVisitor { info.affine_binding = IsAffineBinding(/*realize=*/block2realize_.at(scope_root->stmt), /*loop_var_ranges=*/LoopDomainOfSRefTreePath(srefs_.back()), - /*analyzer=*/&analyzer_); + /*analyzer=*/analyzer_.get()); } // Set `region_cover` to true, will be updated on its scope block info.region_cover = true; @@ -296,7 +298,7 @@ class SBlockInfoCollector : private StmtVisitor { /*predicate=*/producer_realize->predicate, /*dom_low_inclusive=*/parent_sref, /*dom_high_exclusive=*/lca, - /*analyzer=*/&analyzer_)); + /*analyzer=*/analyzer_.get())); } } } @@ -315,9 +317,9 @@ class SBlockInfoCollector : private StmtVisitor { /*predicate=*/consumer_realize->predicate, /*dom_low_inclusive=*/parent_sref, /*dom_high_exclusive=*/lca, - /*analyzer=*/&analyzer_); + /*analyzer=*/analyzer_.get()); if (!ProducerCoversConsumer(buffer->shape, produced_region, consumed_region, - &analyzer_)) { + analyzer_.get())) { region_cover = false; self_->block_info.at(consumer_block_sref).region_cover = region_cover; break; @@ -332,7 +334,7 @@ class SBlockInfoCollector : private StmtVisitor { } void VisitStmt_(const ForNode* loop) final { - analyzer_.Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + analyzer_->Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); PushSRef(loop); VisitStmt(loop->body); PopSRef(); diff --git a/src/s_tir/schedule/traced_schedule.cc b/src/s_tir/schedule/traced_schedule.cc index 22465846e86c..21e902978081 100644 --- a/src/s_tir/schedule/traced_schedule.cc +++ b/src/s_tir/schedule/traced_schedule.cc @@ -28,7 +28,7 @@ Schedule Schedule::Traced(IRModule mod, LinearCongruentialEngine::TRandState see n->state_ = ScheduleState(mod, debug_mask, enable_check); n->error_render_level_ = error_render_level; n->symbol_table_ = {}; - n->analyzer_ = std::make_unique(); + n->analyzer_ = arith::Analyzer(); n->trace_ = Trace(); n->Seed(seed); GlobalVar gv; @@ -45,7 +45,7 @@ Schedule TracedScheduleNode::Copy() { n->error_render_level_ = this->error_render_level_; ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_); n->func_working_on_ = this->func_working_on_; - n->analyzer_ = std::make_unique(); // new analyzer needed because it is stateful + n->analyzer_ = arith::Analyzer(); // new analyzer needed because it is stateful n->rand_state_ = ForkSeed(); n->trace_ = Trace(this->trace_->insts, this->trace_->decisions); return Schedule(std::move(n)); diff --git a/src/s_tir/schedule/transform.cc b/src/s_tir/schedule/transform.cc index ee273597c841..bdbe8533373e 100644 --- a/src/s_tir/schedule/transform.cc +++ b/src/s_tir/schedule/transform.cc @@ -395,8 +395,8 @@ ffi::Optional TileWithTensorIntrin(const s_tir::Schedule& sch, const tirx::ForNode* desc_loop = kv.second.get(); TVM_FFI_ICHECK(block_loop != nullptr && desc_loop != nullptr); // Extract the loop extent - PrimExpr block_extent = analyzer.Simplify(block_loop->extent); - PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent); + PrimExpr block_extent = analyzer->Simplify(block_loop->extent); + PrimExpr desc_extent = analyzer->Simplify(desc_loop->extent); const auto* int_block_extent = block_extent.as(); const auto* int_desc_extent = desc_extent.as(); TVM_FFI_ICHECK(int_block_extent != nullptr && int_desc_extent != nullptr); diff --git a/src/s_tir/schedule/transform.h b/src/s_tir/schedule/transform.h index 21e29b3e2170..6221cb35de05 100644 --- a/src/s_tir/schedule/transform.h +++ b/src/s_tir/schedule/transform.h @@ -236,13 +236,13 @@ class BlockBufferAccessSimplifier : public arith::IRMutatorWithAnalyzer { * \param analyzer The arithmetic analyzer * \return The simplified statement */ - static Stmt Simplify(const Stmt& stmt, arith::Analyzer* analyzer) { + static Stmt Simplify(const Stmt& stmt, arith::AnalyzerObj* analyzer) { BlockBufferAccessSimplifier simplifier(analyzer); return simplifier(stmt); } private: - explicit BlockBufferAccessSimplifier(arith::Analyzer* analyzer) + explicit BlockBufferAccessSimplifier(arith::AnalyzerObj* analyzer) : IRMutatorWithAnalyzer(analyzer) {} using IRMutatorWithAnalyzer::VisitExpr_; diff --git a/src/s_tir/transform/bound_checker.cc b/src/s_tir/transform/bound_checker.cc index ba449ad19449..8f352e4888e2 100644 --- a/src/s_tir/transform/bound_checker.cc +++ b/src/s_tir/transform/bound_checker.cc @@ -206,8 +206,8 @@ class BoundChecker : public StmtExprMutator { } // Try to simplify index and bound. - index = analyzer_.Simplify(index); - upper_bound = analyzer_.Simplify(upper_bound); + index = analyzer_->Simplify(index); + upper_bound = analyzer_->Simplify(upper_bound); // Cast to the same type - signed, to be able to check lower bound. index = Cast(DataType::Int(64), index); diff --git a/src/s_tir/transform/canonicalize_loop.cc b/src/s_tir/transform/canonicalize_loop.cc index 5ee678789f80..9ecb242a10fe 100644 --- a/src/s_tir/transform/canonicalize_loop.cc +++ b/src/s_tir/transform/canonicalize_loop.cc @@ -50,7 +50,7 @@ class LoopCanonicalizer : public StmtExprMutator { PrimExpr step = op->step.value_or(make_const(loop_var->dtype, 1)); // report warning for negative step, since it would be a forever loop - if (!analyzer_.CanProveGreaterEqual(step, 1)) { + if (!analyzer_->CanProveGreaterEqual(step, 1)) { // TODO(tvm): prove dynamic shaped step TVM_FFI_THROW(InternalError) << "Loop step for " << op->loop_var << " may not be positive: " << step; @@ -60,7 +60,7 @@ class LoopCanonicalizer : public StmtExprMutator { auto n = CopyOnWrite(op); n->body = VisitStmt(op->body); n->min = make_zero(loop_var->dtype); - n->extent = analyzer_.Simplify(ceildiv(op->extent, step)); + n->extent = analyzer_->Simplify(ceildiv(op->extent, step)); n->step = std::nullopt; new_iter_info_.erase(loop_var); return For(n); diff --git a/src/s_tir/transform/compact_buffer_region.cc b/src/s_tir/transform/compact_buffer_region.cc index 566fa42cb8b5..d02e90701696 100644 --- a/src/s_tir/transform/compact_buffer_region.cc +++ b/src/s_tir/transform/compact_buffer_region.cc @@ -49,13 +49,14 @@ using support::NDIntSet; /*! \brief a more constrained bound estimate for n-dimentional int set */ NDIntSet NDIntSetEval(Region region, PrimExpr predicate, const std::unordered_map& dom_map, - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { std::unordered_map var_dom; for (const auto& it : dom_map) { var_dom[ffi::GetRef(it.first)] = it.second.CoverRange(Range::FromMinExtent(0, 0)); } + arith::Analyzer analyzer_ref = ffi::GetRef(analyzer); ffi::Optional> eval_res = - arith::EstimateRegionUpperBound(region, var_dom, predicate, analyzer); + arith::EstimateRegionUpperBound(region, var_dom, predicate, analyzer_ref); if (eval_res.defined()) { return NDIntSet(eval_res.value().begin(), eval_res.value().end()); @@ -166,7 +167,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { op->thread_binding.value()->thread_tag) : IterVar(Range(), op->loop_var, IterVarType::kDataPar); ancestor_iters_.push_back(iter); - dom_analyzer_.Bind(op->loop_var, loop_range); + dom_analyzer_->Bind(op->loop_var, loop_range); dom_map_.emplace(op->loop_var.get(), arith::IntSet::FromRange(loop_range)); size_t n_pending_before = pending_flat_alloc_buffers_.size(); StmtExprVisitor::VisitStmt_(op); @@ -179,7 +180,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { void VisitStmt_(const BindNode* op) final { StmtExprVisitor::VisitExpr(op->value); if (arith::IsIndexType(op->value->dtype)) { - dom_analyzer_.Bind(op->var, op->value); + dom_analyzer_->Bind(op->var, op->value); dom_map_.emplace(op->var.get(), arith::IntSet::SinglePoint(op->value)); } } @@ -187,7 +188,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { void VisitExpr_(const LetNode* op) final { StmtExprVisitor::VisitExpr(op->value); if (arith::IsIndexType(op->value->dtype)) { - dom_analyzer_.Bind(op->var, op->value); + dom_analyzer_->Bind(op->var, op->value); dom_map_.emplace(op->var.get(), arith::IntSet::SinglePoint(op->value)); } StmtExprVisitor::VisitExpr(op->body); @@ -321,7 +322,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { if (!dom.defined()) { // dom is empty for legacy te schedule dom = Range::FromMinExtent(make_zero(op->value->dtype), op->value); } - dom_analyzer_.Bind(iter->var, dom); + dom_analyzer_->Bind(iter->var, dom); dom_map_.emplace(iter->var.get(), arith::IntSet::FromRange(dom)); size_t n_pending_before = pending_flat_alloc_buffers_.size(); StmtExprVisitor::VisitStmt_(op); @@ -367,13 +368,13 @@ class BufferAccessRegionCollector : public StmtExprVisitor { if (pred->dtype.is_bool()) return pred; return pred != make_zero(pred->dtype); }; - PrimExpr predicate = dom_analyzer_.Simplify( + PrimExpr predicate = dom_analyzer_->Simplify( std::accumulate(pending_conditions_.begin(), pending_conditions_.end(), const_true(), [normalize_pred](const PrimExpr& x, const PrimExpr& y) { return normalize_pred(x) && normalize_pred(y); })); NDIntSet nd_int_set = - NDIntSetEval(buffer_region->region, predicate, dom_map_, &dom_analyzer_); + NDIntSetEval(buffer_region->region, predicate, dom_map_, dom_analyzer_.get()); // Step 3. Restore the non-relaxed ancestor loops domain for (size_t i = 0; i < n_ancestor_loops; ++i) { @@ -440,16 +441,16 @@ class BufferAccessRegionCollector : public StmtExprVisitor { Range range = int_set.CoverRange(original); PrimExpr min, extent; if (collect_inbound_) { - min = dom_analyzer_.Simplify(tvm::max(0, range->min)); + min = dom_analyzer_->Simplify(tvm::max(0, range->min)); extent = range->extent; // Apply stronger symbolic proof to help us remove symbolic min here. - if (!dom_analyzer_.CanProveLessEqualThanSymbolicShapeValue(extent, original_shape[i])) { + if (!dom_analyzer_->CanProveLessEqualThanSymbolicShapeValue(extent, original_shape[i])) { extent = tvm::min(original_shape[i], range->extent); } - extent = dom_analyzer_.Simplify(extent); + extent = dom_analyzer_->Simplify(extent); } else { - min = dom_analyzer_.Simplify(range->min); - extent = dom_analyzer_.Simplify(range->extent); + min = dom_analyzer_->Simplify(range->min); + extent = dom_analyzer_->Simplify(range->extent); } // We check the buffer extent is pure and not loop dependent, since loop dependent @@ -465,7 +466,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { }; if (UsesVar(extent, is_loop_var)) { // try estimate a constant upperbound on region's extent - int64_t upperbound = dom_analyzer_.const_int_bound(extent)->max_value; + int64_t upperbound = dom_analyzer_->const_int_bound(extent)->max_value; if (upperbound != arith::ConstIntBound::kPosInf) { extent = make_const(extent->dtype, upperbound); } else { diff --git a/src/s_tir/transform/hoist_expression.cc b/src/s_tir/transform/hoist_expression.cc index 5cb851ca2a52..ac48593bd2a1 100644 --- a/src/s_tir/transform/hoist_expression.cc +++ b/src/s_tir/transform/hoist_expression.cc @@ -450,7 +450,7 @@ class ExpressionHoister : public arith::IRMutatorWithAnalyzer { auto loop_info = HoistInfoCollector::Collect(stmt, config); arith::Analyzer analyzer; - ExpressionHoister hoister(std::move(loop_info), config, &analyzer); + ExpressionHoister hoister(std::move(loop_info), config, analyzer.get()); stmt = hoister(std::move(stmt)); stmt = ConvertSSA(std::move(stmt)); return stmt; @@ -462,7 +462,7 @@ class ExpressionHoister : public arith::IRMutatorWithAnalyzer { using Parent::VisitStmt_; explicit ExpressionHoister(std::vector loop_info, - HoistExpressionConfig config, arith::Analyzer* analyzer) + HoistExpressionConfig config, arith::AnalyzerObj* analyzer) : Parent(analyzer), config_(config) { for (auto& info : loop_info) { // Mark let bindings to use if they are enabled on their own. diff --git a/src/s_tir/transform/inject_permuted_layout.cc b/src/s_tir/transform/inject_permuted_layout.cc index 4c5b7ad00803..fe90f38cec67 100644 --- a/src/s_tir/transform/inject_permuted_layout.cc +++ b/src/s_tir/transform/inject_permuted_layout.cc @@ -45,14 +45,14 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { static PrimFunc Transform(PrimFunc func) { Analyzer analyzer; - auto new_body = PermutedLayoutInjector(func, &analyzer)(func->body); + auto new_body = PermutedLayoutInjector(func, analyzer.get())(func->body); auto func_node = func.CopyOnWrite(); func_node->body = new_body; return func; } private: - explicit PermutedLayoutInjector(PrimFunc func, Analyzer* analyzer) + explicit PermutedLayoutInjector(PrimFunc func, AnalyzerObj* analyzer) : IRMutatorWithAnalyzer(analyzer) { buffer_map_.insert(func->buffer_map.begin(), func->buffer_map.end()); } diff --git a/src/s_tir/transform/inject_software_pipeline.cc b/src/s_tir/transform/inject_software_pipeline.cc index 79e3289d04be..d9da151f392f 100644 --- a/src/s_tir/transform/inject_software_pipeline.cc +++ b/src/s_tir/transform/inject_software_pipeline.cc @@ -374,7 +374,7 @@ class PipelineRewriter : public StmtExprMutator { // to ensure the epilogue interval do not overlap the prologue interval. PrimExpr epigogue_start = pipeline_loop_->min + pipeline_loop_->extent; ffi::Optional extra_epilogue_lower_bound = std::nullopt; - if (max_stage_ > 1 && !analyzer_.CanProveGreaterEqual(pipeline_loop_->extent, max_stage_)) { + if (max_stage_ > 1 && !analyzer_->CanProveGreaterEqual(pipeline_loop_->extent, max_stage_)) { if (is_const_int(epigogue_start)) { epigogue_start = max(epigogue_start, pipeline_loop_->min + max_stage_); } else { @@ -609,7 +609,7 @@ class PipelineRewriter : public StmtExprMutator { // Determine where to insert async_wait and the corresponding wait count. void PopulateWaitCounts(const std::vector& new_blocks, - arith::Analyzer* ana_normalized, + arith::AnalyzerObj* ana_normalized, const std::unordered_map& buffer_to_commit_group, std::map* async_states_local) { for (size_t i = 0; i < new_blocks.size(); ++i) { @@ -714,7 +714,7 @@ class PipelineRewriter : public StmtExprMutator { // Here, new_blocks[i].access_index corresponds to "consumer_head". // The difference of producer_head and consumer_head is precisely the number of // async commit groups that can still be in flight after this wait. - sum += analyzer_.Simplify(producer_head.value() - new_blocks[i].access_index); + sum += analyzer_->Simplify(producer_head.value() - new_blocks[i].access_index); } else { // The precise count cannot be determined, give up. return PrimExpr(0); @@ -727,7 +727,7 @@ class PipelineRewriter : public StmtExprMutator { if (!pending_wait.valid()) { pending_wait = {static_cast(i), wait_count}; - } else if (analyzer_.CanProve(wait_count < pending_wait.wait_count)) { + } else if (analyzer_->CanProve(wait_count < pending_wait.wait_count)) { // Coalesce multiple wait_queue if the later one allows fewer in-flight ops. pending_wait = {pending_wait.insert_before, wait_count}; } @@ -739,7 +739,7 @@ class PipelineRewriter : public StmtExprMutator { ffi::Array CompletePipelineLoopStatements( const std::vector& blocks, const std::map& async_states_local, - arith::Analyzer* ana_normalized) const { + arith::AnalyzerObj* ana_normalized) const { std::vector new_blocks = blocks; std::vector commit_group_indices(new_blocks.size(), -1); for (const auto& [stage_id, state] : async_states_local) { @@ -826,22 +826,22 @@ class PipelineRewriter : public StmtExprMutator { auto make_nop = []() { return SBlockRealize({}, const_true(), MakeSBlock(Evaluate(0), {})); }; - if (analyzer_.CanProve(extent <= 0)) { + if (analyzer_->CanProve(extent <= 0)) { return make_nop(); } - bool is_unit_loop = analyzer_.CanProveEqual(extent, 1); + bool is_unit_loop = analyzer_->CanProveEqual(extent, 1); if (is_unit_loop) { new_loop_var = start; // use constants as the loop var for unit loops } else { new_loop_var = pipeline_loop_->loop_var.copy_with_suffix(""); - analyzer_.Bind(Downcast(new_loop_var), Range(start, end)); + analyzer_->Bind(Downcast(new_loop_var), Range(start, end)); } // In contrast to analyzer_ which is bound to [start, end), this one is bound to // the "normalized" range, [pipeline_loop_->min, extent). arith::Analyzer ana_normalized; if (!is_unit_loop) { - ana_normalized.Bind(Downcast(new_loop_var), Range(pipeline_loop_->min, extent)); + ana_normalized->Bind(Downcast(new_loop_var), Range(pipeline_loop_->min, extent)); } std::vector new_blocks; @@ -853,12 +853,12 @@ class PipelineRewriter : public StmtExprMutator { for (const SBlock& block : ordered_stmts_) { int stage = pipeline_info_.at(block).stage; PrimExpr skewed_loop_var = new_loop_var - stage; - PrimExpr inbound = analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) && + PrimExpr inbound = analyzer_->Simplify(pipeline_loop_->min <= skewed_loop_var) && (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent); if (extra_loop_lower_bound.defined()) { - inbound = analyzer_.Simplify(inbound && new_loop_var >= extra_loop_lower_bound.value()); + inbound = analyzer_->Simplify(inbound && new_loop_var >= extra_loop_lower_bound.value()); } - if (analyzer_.CanProve(!inbound)) { + if (analyzer_->CanProve(!inbound)) { continue; } SBlock new_block = Downcast( @@ -910,10 +910,10 @@ class PipelineRewriter : public StmtExprMutator { local_state.producer_head = normalized_access_index; - if (!local_state.predicate || ana_normalized.CanProve(local_state.predicate.value())) { + if (!local_state.predicate || ana_normalized->CanProve(local_state.predicate.value())) { local_state.predicate = inbound; } else if (local_state.predicate) { - local_state.predicate = ana_normalized.Simplify(local_state.predicate.value() & inbound); + local_state.predicate = ana_normalized->Simplify(local_state.predicate.value() & inbound); } SBlockNode* n = new_block.CopyOnWrite(); @@ -933,8 +933,10 @@ class PipelineRewriter : public StmtExprMutator { } } - PopulateWaitCounts(new_blocks, &ana_normalized, buffer_to_commit_group, &async_states_local); - auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local, &ana_normalized); + PopulateWaitCounts(new_blocks, ana_normalized.get(), buffer_to_commit_group, + &async_states_local); + auto stmts = + CompletePipelineLoopStatements(new_blocks, async_states_local, ana_normalized.get()); Stmt new_loop{nullptr}; @@ -958,7 +960,7 @@ class PipelineRewriter : public StmtExprMutator { const int stage_id = kv.first; const AsyncStateLocal& state = kv.second; - if (state.predicate && ana_normalized.CanProve(state.predicate.value()) && + if (state.predicate && ana_normalized->CanProve(state.predicate.value()) && async_states[stage_id].producer_head) { // Advance the "global" producer head if it is still valid and we know exactly how much we // can increment diff --git a/src/s_tir/transform/inject_virtual_thread.cc b/src/s_tir/transform/inject_virtual_thread.cc index f573732cb9f8..61ec03ce352c 100644 --- a/src/s_tir/transform/inject_virtual_thread.cc +++ b/src/s_tir/transform/inject_virtual_thread.cc @@ -183,7 +183,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { using IRMutatorWithAnalyzer::VisitStmt_; // constructor - VTInjector(arith::Analyzer* analyzer, Var var, int num_threads, + VTInjector(arith::AnalyzerObj* analyzer, Var var, int num_threads, const std::unordered_set& touched_var, bool allow_share) : IRMutatorWithAnalyzer(analyzer), var_(var), @@ -541,7 +541,7 @@ Pass InjectVirtualThread() { arith::Analyzer analyzer; - n->body = VirtualThreadInjector(&analyzer)(std::move(n->body)); + n->body = VirtualThreadInjector(analyzer.get())(std::move(n->body)); n->body = ConvertSSA(std::move(n->body)); return f; }; diff --git a/src/s_tir/transform/loop_partition.cc b/src/s_tir/transform/loop_partition.cc index 8eb444dcfd53..c59453d41417 100644 --- a/src/s_tir/transform/loop_partition.cc +++ b/src/s_tir/transform/loop_partition.cc @@ -156,7 +156,7 @@ class CandidateSelector final : public StmtExprVisitor { return; } } else if (op->attr_key == s_tir::attr::pragma_loop_partition_hint) { - if (analyzer_.CanProve(op->value)) { + if (analyzer_->CanProve(op->value)) { const VarNode* var = nullptr; if (op->node.as()) { var = op->node.as(); @@ -424,7 +424,7 @@ class LoopPartitioner : public StmtMutator { } Stmt VisitStmt_(const ForNode* op) final { - analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent), true); + analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent), true); auto fs = ffi::GetRef(op); if (selector.candidates.count(fs)) { Stmt s = TryPartition(fs, op->loop_var, op->min, op->min + op->extent - 1, op->body, false); @@ -499,7 +499,7 @@ std::pair LoopPartitioner::GetIntervalAndCondset( for (const auto& kv : partitions) { if (kv.first.second == cond_value) { arith::IntervalSet interval = Downcast(kv.second); - arith::IntervalSet intersection = arith::Intersect(&analyzer_, interval, for_interval); + arith::IntervalSet intersection = arith::Intersect(analyzer_.get(), interval, for_interval); if (!intersection->IsEmpty()) { sets.push_back(kv.second); @@ -518,14 +518,15 @@ std::pair LoopPartitioner::GetIntervalAndCondset( for (const auto& kv : partitions) { if (kv.first.second == cond_value) { arith::IntervalSet cond_interval = Downcast(kv.second); - arith::IntervalSet intersection = arith::Intersect(&analyzer_, cond_interval, for_interval); + arith::IntervalSet intersection = + arith::Intersect(analyzer_.get(), cond_interval, for_interval); if (!intersection->IsEmpty()) { - cond_intersection = arith::Intersect(&analyzer_, cond_intersection, cond_interval); + cond_intersection = arith::Intersect(analyzer_.get(), cond_intersection, cond_interval); // Return the latest interval and cond_set if the cond_intersection is nothing. if (!cond_intersection->IsEmpty()) { cond_set.insert(kv.first.first); - interval = arith::IntervalSet(analyzer_.Simplify(cond_intersection->min_value), - analyzer_.Simplify(cond_intersection->max_value)); + interval = arith::IntervalSet(analyzer_->Simplify(cond_intersection->min_value), + analyzer_->Simplify(cond_intersection->max_value)); } else { break; } @@ -629,8 +630,8 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim if (intset.IsSinglePoint()) { auto single_point = intset.PointValue(); // Check if the single point is outside the `for_interval` - bool is_inside = analyzer_.CanProve(single_point >= for_interval.min()) && - analyzer_.CanProve(single_point <= for_interval.max()); + bool is_inside = analyzer_->CanProve(single_point >= for_interval.min()) && + analyzer_->CanProve(single_point <= for_interval.max()); if (is_inside) { // If any single point is inside, this is an error condition LOG(ERROR) << "unexpected case happened."; @@ -662,7 +663,7 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim if (!opt_cond_value.has_value()) { if (has_partition_hint_ && unroll_loop_with_partition_hint_no_interval_ && - analyzer_.CanProve(max - min > 0)) { + analyzer_->CanProve(max - min > 0)) { auto new_body = VisitAndMutate(body); return For(var, min, max - min + 1, ForKind::kUnrolled, new_body); } @@ -682,15 +683,15 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim Stmt pre_stmt; bool pre_stmt_recurse = true; if (middle_interval_i->HasLowerBound()) { - body_begin = analyzer_.Simplify(middle_interval.min()); - if (!analyzer_.CanProve(body_begin == min)) { - PrimExpr extent = analyzer_.Simplify(body_begin - min); - if (!analyzer_.CanProve(extent > 0)) { + body_begin = analyzer_->Simplify(middle_interval.min()); + if (!analyzer_->CanProve(body_begin == min)) { + PrimExpr extent = analyzer_->Simplify(body_begin - min); + if (!analyzer_->CanProve(extent > 0)) { body_begin = tvm::max(body_begin, min); // stop recursing on this interval if we can't prove it has non-negative length pre_stmt_recurse = false; } - if (!analyzer_.CanProve(extent <= 0)) { + if (!analyzer_->CanProve(extent <= 0)) { if (!partition_thread_scope) { Stmt pre_body = Substitute(body, {{Var{var}, var + min}}); pre_stmt = MakeFor(stmt.get(), body_begin - min, pre_body); @@ -707,16 +708,16 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim Stmt post_stmt; bool post_stmt_recurse = true; if (middle_interval_i->HasUpperBound()) { - post_doubt_begin = analyzer_.Simplify(middle_interval.max() + 1); - if (!analyzer_.CanProve(middle_interval.max() == max)) { + post_doubt_begin = analyzer_->Simplify(middle_interval.max() + 1); + if (!analyzer_->CanProve(middle_interval.max() == max)) { // require the extent to be non-negative - PrimExpr extent = analyzer_.Simplify(max - post_doubt_begin + 1); - if (!analyzer_.CanProve(extent > 0)) { + PrimExpr extent = analyzer_->Simplify(max - post_doubt_begin + 1); + if (!analyzer_->CanProve(extent > 0)) { post_doubt_begin = tvm::min(post_doubt_begin, max + 1); // stop recursing on this interval if we can't prove it has non-negative length post_stmt_recurse = false; } - if (!analyzer_.CanProve(extent <= 0)) { + if (!analyzer_->CanProve(extent <= 0)) { if (!partition_thread_scope) { Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}}); post_stmt = MakeFor(stmt.get(), extent, post_body); @@ -732,7 +733,7 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim // Generating code for middle subrange if (!partition_thread_scope) { Stmt mid_stmt; - if (!analyzer_.CanProve(body_begin >= post_doubt_begin)) { + if (!analyzer_->CanProve(body_begin >= post_doubt_begin)) { // [body_begin, post_doubt_begin) Stmt simplified_body = ConditionEliminator(cond_set, cond_value)(body); Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}}); @@ -753,8 +754,9 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim s = SeqStmt::Flatten(pre_stmt, mid_stmt, post_stmt); } else { PrimExpr cond = const_true(); - if (!analyzer_.CanProve(body_begin == min)) cond = cond && (var >= body_begin); - if (!analyzer_.CanProve(post_doubt_begin == (max + 1))) cond = cond && (var < post_doubt_begin); + if (!analyzer_->CanProve(body_begin == min)) cond = cond && (var >= body_begin); + if (!analyzer_->CanProve(post_doubt_begin == (max + 1))) + cond = cond && (var < post_doubt_begin); s = ThreadPartitionInserter(cond_set, cond)(stmt); } s = ConvertSSA(s); @@ -765,7 +767,7 @@ inline Stmt LoopPartitioner::MakeFor(const ffi::Object* node, PrimExpr extent, S const ForNode* for_node = static_cast(node); TVM_FFI_ICHECK(for_node); - if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1)) && + if (analyzer_->CanProve(extent == make_const(DataType::Int(32), 1)) && !no_unroll_loop_with_extent_one_ && for_node->annotations.empty()) { // If the loop extent is 1, do not create the loop anymore return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}}); diff --git a/src/s_tir/transform/lower_async_dma.cc b/src/s_tir/transform/lower_async_dma.cc index 6833f989f801..1178c1aa48c3 100644 --- a/src/s_tir/transform/lower_async_dma.cc +++ b/src/s_tir/transform/lower_async_dma.cc @@ -46,7 +46,7 @@ using namespace tvm::tirx; class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer { public: - explicit AsyncDMALowerer(bool dma_bypass_cache, arith::Analyzer* analyzer) + explicit AsyncDMALowerer(bool dma_bypass_cache, arith::AnalyzerObj* analyzer) : IRMutatorWithAnalyzer(analyzer), dma_bypass_cache_(dma_bypass_cache) {} // TODO(leiwang1999): split lower async DMA support for CUDA and Hexagon Backend @@ -58,7 +58,7 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer { // if for loop is not a memcpy of a contiguous region, it might be a cuda cp.async behavior std::optional mem_copy = - s_tir::IdentifyMemCpy(ffi::GetRef(loop), analyzer_); + s_tir::IdentifyMemCpy(ffi::GetRef(loop), ffi::GetRef(analyzer_)); if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 || mem_copy->source->region.size() != 1) { return arith::IRMutatorWithAnalyzer::VisitStmt_(loop); @@ -176,7 +176,7 @@ Pass LowerAsyncDMA() { arith::Analyzer analyzer; bool dma_bypass_cache = ctx->GetConfig("tirx.experimental_dma_bypass_cache", false).value(); - fptr->body = AsyncDMALowerer(dma_bypass_cache, &analyzer)(std::move(fptr->body)); + fptr->body = AsyncDMALowerer(dma_bypass_cache, analyzer.get())(std::move(fptr->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "s_tir.LowerAsyncDMA", {}); diff --git a/src/s_tir/transform/lower_cross_thread_reduction.cc b/src/s_tir/transform/lower_cross_thread_reduction.cc index 361466a2f6a1..18ff343d4dff 100644 --- a/src/s_tir/transform/lower_cross_thread_reduction.cc +++ b/src/s_tir/transform/lower_cross_thread_reduction.cc @@ -109,7 +109,7 @@ bool IsDominantBlock(const SBlock& scope_block, const SBlock& block) { * check again. */ bool IsReductionBlock(const SBlockRealize& realize, const ffi::Map& loop_range_map, - const SBlock& scope_block, arith::Analyzer* analyzer) { + const SBlock& scope_block, arith::AnalyzerObj* analyzer) { const auto* block = realize->block.as(); // Cond 1. The block has the `init` statement. if (!block->init.defined()) { @@ -548,7 +548,7 @@ class CrossThreadReductionTransformer : public StmtMutator { // Step 1. If the block is not a reduction block, cross-thread reduction is not needed. if (!IsReductionBlock(ffi::GetRef(realize), loop_range_map_, - ffi::GetRef(block_stack_.back()), &analyzer_)) { + ffi::GetRef(block_stack_.back()), analyzer_.get())) { return {}; } diff --git a/src/s_tir/transform/lower_match_buffer.cc b/src/s_tir/transform/lower_match_buffer.cc index 4caa02bc713c..17844c0a286f 100644 --- a/src/s_tir/transform/lower_match_buffer.cc +++ b/src/s_tir/transform/lower_match_buffer.cc @@ -89,7 +89,7 @@ class MatchBufferLower : public StmtExprMutator { } Stmt VisitStmt_(const ForNode* op) final { - analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); return StmtExprMutator::VisitStmt_(op); } @@ -205,7 +205,7 @@ class MatchBufferLower : public StmtExprMutator { if (buffer_start_indices.size() == 1) { Bind(buffer->elem_offset, buffer_start_indices[0], buffer->name + ".elem_offset"); TVM_FFI_ICHECK( - analyzer_.CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0)) + analyzer_->CanProve(truncmod(buffer->elem_offset, buffer->offset_factor) == 0)) << "The source elem_offset " << buffer_start_indices[0] << " does not satisfy the offset_factor " << buffer->offset_factor << "."; } else { @@ -262,7 +262,7 @@ class MatchBufferLower : public StmtExprMutator { auto it = var_map_.find(v); if (it == var_map_.end()) { var_map_.Set(v, value); - analyzer_.Bind(v, value); + analyzer_->Bind(v, value); } else { AssertBinding((*it).second, value, arg_name); } @@ -273,8 +273,9 @@ class MatchBufferLower : public StmtExprMutator { void AssertBinding(const PrimExpr& lhs, const PrimExpr& rhs, const std::string& arg_name = "argument") { - TVM_FFI_ICHECK(analyzer_.CanProve(lhs == rhs)) << "The buffer match constraint for " << arg_name - << " unmet: " << lhs << "==" << rhs << "."; + TVM_FFI_ICHECK(analyzer_->CanProve(lhs == rhs)) + << "The buffer match constraint for " << arg_name << " unmet: " << lhs << "==" << rhs + << "."; } private: diff --git a/src/s_tir/transform/lower_thread_allreduce.cc b/src/s_tir/transform/lower_thread_allreduce.cc index 14e4d5f1b4f0..98c6a363d9dd 100644 --- a/src/s_tir/transform/lower_thread_allreduce.cc +++ b/src/s_tir/transform/lower_thread_allreduce.cc @@ -710,7 +710,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // The local buffer index. PrimExpr BufIndex(PrimExpr reduce_index, PrimExpr group_index, int reduce_extent) { if (!is_zero(group_index)) { - return analyzer_.Simplify(group_index * reduce_extent + reduce_index); + return analyzer_->Simplify(group_index * reduce_extent + reduce_index); } else { return reduce_index; } diff --git a/src/s_tir/transform/memhammer_coalesce.cc b/src/s_tir/transform/memhammer_coalesce.cc index fb67c3eae1b0..7f65941fa1fa 100644 --- a/src/s_tir/transform/memhammer_coalesce.cc +++ b/src/s_tir/transform/memhammer_coalesce.cc @@ -105,7 +105,7 @@ Stmt SplitBindVectorize(const Stmt& stmt, const ConstraintSet& constraints) { for (int i = 0; i < n; i++) { const PrimExpr& factor = factors[i]; Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i)); - analyzer.Bind(var, Range::FromMinExtent(0, factor)); + analyzer->Bind(var, Range::FromMinExtent(0, factor)); new_loop_vars.push_back(var); } // substitute fused loop var with new loop vars @@ -123,7 +123,7 @@ Stmt SplitBindVectorize(const Stmt& stmt, const ConstraintSet& constraints) { } }); PrimExpr predicate = substitute_value < loop->extent; - if (!analyzer.CanProve(predicate)) { + if (!analyzer->CanProve(predicate)) { body = IfThenElse(predicate, body); } body = For(new_loop_vars.back(), 0, vector_len, ForKind::kVectorized, std::move(body)); @@ -167,7 +167,7 @@ ffi::Array GetMapping(const Stmt& stmt, const ConstraintSet& constrain ffi::Array result; arith::Analyzer analyzer; for (int i = 0; i < static_cast(write_region->region.size()); i++) { - PrimExpr pattern = analyzer.Simplify(write_index[i] - write_region->region[i]->min); + PrimExpr pattern = analyzer->Simplify(write_index[i] - write_region->region[i]->min); if (!is_zero(pattern)) { result.push_back(pattern); } @@ -191,7 +191,7 @@ Stmt InverseMapping::Rewrite(const Stmt& stmt, const ConstraintSet& constraints, arith::Analyzer analyzer; DiagnosticContext diag_ctx(DiagnosticContext::Default(IRModule())); auto iter_map = - arith::DetectIterMap(mapping_pattern, var_range, const_true(), arith::Bijective, &analyzer); + arith::DetectIterMap(mapping_pattern, var_range, const_true(), arith::Bijective, analyzer); TVM_FFI_ICHECK_EQ(iter_map->indices.size(), loop_vars.size()); ffi::Map inverse_mapping = arith::InverseAffineIterMap(iter_map->indices, loop_vars); diff --git a/src/s_tir/transform/memhammer_intermediate_stage.cc b/src/s_tir/transform/memhammer_intermediate_stage.cc index 63e51cd7b8f9..0d410f016c52 100644 --- a/src/s_tir/transform/memhammer_intermediate_stage.cc +++ b/src/s_tir/transform/memhammer_intermediate_stage.cc @@ -293,7 +293,7 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, ffi::S TVM_FFI_ICHECK(target_buffer_load->indices.size() == buffer_load->indices.size()); for (size_t i = 0; i < target_buffer_load->indices.size(); i++) { TVM_FFI_ICHECK( - analyzer.CanProveEqual(target_buffer_load->indices[i], buffer_load->indices[i])); + analyzer->CanProveEqual(target_buffer_load->indices[i], buffer_load->indices[i])); } } } diff --git a/src/s_tir/transform/memhammer_lower_auto_copy.cc b/src/s_tir/transform/memhammer_lower_auto_copy.cc index 3db122b2ea4e..af805d64f7eb 100644 --- a/src/s_tir/transform/memhammer_lower_auto_copy.cc +++ b/src/s_tir/transform/memhammer_lower_auto_copy.cc @@ -477,7 +477,7 @@ class AutoPadder { } }); arith::Analyzer analyzer; - return !analyzer.CanProve(Substitute(e2 - e1, subst_map) != 1); + return !analyzer->CanProve(Substitute(e2 - e1, subst_map) != 1); } void VisitStmt_(const ForNode* op) final { @@ -514,7 +514,7 @@ class AutoPadder { ffi::Array substitued_indices; arith::Analyzer analyzer; for (const PrimExpr& e : op->indices) { - substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_))); + substitued_indices.push_back(analyzer->Simplify(Substitute(e, substitute_map_))); } std::vector> iter_space = PatternCollector::CollectIterationSpace(substitued_indices, var_range_, data_bits_); @@ -542,7 +542,7 @@ class AutoPadder { ffi::Array substitued_indices; arith::Analyzer analyzer; for (const PrimExpr& e : op->indices) { - substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_))); + substitued_indices.push_back(analyzer->Simplify(Substitute(e, substitute_map_))); } std::vector> iter_space = PatternCollector::CollectIterationSpace(substitued_indices, var_range_, data_bits_); @@ -584,7 +584,7 @@ class AutoPadder { ffi::Array substitued_indices; arith::Analyzer analyzer; for (const PrimExpr& e : indices) { - substitued_indices.push_back(analyzer.Simplify(Substitute(e, substitute_map_))); + substitued_indices.push_back(analyzer->Simplify(Substitute(e, substitute_map_))); } std::vector> iter_space = PatternCollector::CollectIterationSpace( substitued_indices, var_range_, data_bits_); diff --git a/src/s_tir/transform/memhammer_tensorcore_rewrite.cc b/src/s_tir/transform/memhammer_tensorcore_rewrite.cc index 1a4532b8a4aa..5a3b48521873 100644 --- a/src/s_tir/transform/memhammer_tensorcore_rewrite.cc +++ b/src/s_tir/transform/memhammer_tensorcore_rewrite.cc @@ -43,8 +43,8 @@ std::pair> TileWmmaBlock(Stmt stmt) { PrimExpr extent_last2 = loops[n - 2]->extent; { arith::Analyzer analyzer; - if (!analyzer.CanProveEqual(floormod(extent_last1, 16), 0) || - !analyzer.CanProveEqual(floormod(extent_last2, 16), 0)) { + if (!analyzer->CanProveEqual(floormod(extent_last1, 16), 0) || + !analyzer->CanProveEqual(floormod(extent_last2, 16), 0)) { return std::make_pair(stmt, std::nullopt); } } @@ -371,8 +371,8 @@ std::pair> TileMmaToGlobalBlock(Stmt stmt) { { arith::Analyzer analyzer; // Only tile when both extent % 8 == 0 - if (!analyzer.CanProveEqual(floormod(extent_last1, 8), 0) || - !analyzer.CanProveEqual(floormod(extent_last2, 8), 0)) { + if (!analyzer->CanProveEqual(floormod(extent_last1, 8), 0) || + !analyzer->CanProveEqual(floormod(extent_last2, 8), 0)) { return std::make_pair(stmt, std::nullopt); } } diff --git a/src/s_tir/transform/renormalize_split_pattern.cc b/src/s_tir/transform/renormalize_split_pattern.cc index ae3d048b8892..f185e66a0731 100644 --- a/src/s_tir/transform/renormalize_split_pattern.cc +++ b/src/s_tir/transform/renormalize_split_pattern.cc @@ -52,7 +52,7 @@ using namespace arith; class SplitPatternReNormalizer : public IRMutatorWithAnalyzer { public: - explicit SplitPatternReNormalizer(Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer) {} + explicit SplitPatternReNormalizer(AnalyzerObj* analyzer) : IRMutatorWithAnalyzer(analyzer) {} using IRMutatorWithAnalyzer::VisitExpr_; @@ -201,7 +201,7 @@ Pass RenormalizeSplitPattern() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); arith::Analyzer analyzer; - n->body = SplitPatternReNormalizer(&analyzer)(std::move(n->body)); + n->body = SplitPatternReNormalizer(analyzer.get())(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "s_tir.RenormalizeSplitPattern", {}); diff --git a/src/s_tir/transform/transform_mma_buffer_layout.cc b/src/s_tir/transform/transform_mma_buffer_layout.cc index d3518ccd81ca..e15145180cf4 100644 --- a/src/s_tir/transform/transform_mma_buffer_layout.cc +++ b/src/s_tir/transform/transform_mma_buffer_layout.cc @@ -135,7 +135,7 @@ class MmaBufferLayoutTransformer : public StmtExprMutator { const auto index_map_func = tvm::ffi::Function::GetGlobal("tirx.index_map_m16n8k8.matrixC"); TVM_FFI_ICHECK(index_map_func.has_value()); auto index_map = IndexMap::FromFunc(2, *index_map_func); - auto new_indices = index_map->MapIndices(store->indices, &analyzer); + auto new_indices = index_map->MapIndices(store->indices, analyzer); n->buffer = buffer_map_[store->buffer]; n->indices = std::move(new_indices); } else if (store->buffer.scope() == "m16n8k8.matrixA" || @@ -154,7 +154,7 @@ class MmaBufferLayoutTransformer : public StmtExprMutator { const auto index_map_func = tvm::ffi::Function::GetGlobal("tirx.index_map_m16n8k8.matrixC"); TVM_FFI_ICHECK(index_map_func.has_value()); auto index_map = IndexMap::FromFunc(2, *index_map_func); - auto new_indices = index_map->MapIndices(load->indices, &analyzer); + auto new_indices = index_map->MapIndices(load->indices, analyzer); n->buffer = buffer_map_[load->buffer]; n->indices = std::move(new_indices); } else if (load->buffer.scope() == "m16n8k8.matrixA" || diff --git a/src/s_tir/transform/unify_thread_binding.cc b/src/s_tir/transform/unify_thread_binding.cc index 85333b6efcaf..ec2f9ebc6fad 100644 --- a/src/s_tir/transform/unify_thread_binding.cc +++ b/src/s_tir/transform/unify_thread_binding.cc @@ -115,8 +115,8 @@ class ThreadBindingUnifier : public StmtExprMutator { ffi::Map::iterator it = thread_tag2iter_var_map_.find(thread_tag); if (it != thread_tag2iter_var_map_.end()) { new_iter_var = (*it).second; - TVM_FFI_ICHECK(ana.CanProveEqual(dom->min, new_iter_var->dom->min)); - TVM_FFI_CHECK(ana.CanProveEqual(dom->extent, new_iter_var->dom->extent), ValueError) + TVM_FFI_ICHECK(ana->CanProveEqual(dom->min, new_iter_var->dom->min)); + TVM_FFI_CHECK(ana->CanProveEqual(dom->extent, new_iter_var->dom->extent), ValueError) << "All loops that are bound to `" << thread_tag << "` should have the same extent. However, there are two loops with extent " << new_iter_var->dom->extent << " and " << dom->extent << ", which are not equal"; diff --git a/src/s_tir/transform/using_assume_to_reduce_branches.cc b/src/s_tir/transform/using_assume_to_reduce_branches.cc index 672769949c03..0935ab5faafb 100644 --- a/src/s_tir/transform/using_assume_to_reduce_branches.cc +++ b/src/s_tir/transform/using_assume_to_reduce_branches.cc @@ -115,7 +115,7 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { public: using Parent = IRMutatorWithAnalyzer; - explicit ParseAssumeAndOvercompute(Analyzer* analyzer) : Parent(analyzer) {} + explicit ParseAssumeAndOvercompute(AnalyzerObj* analyzer) : Parent(analyzer) {} private: using Parent::VisitExpr_; @@ -380,7 +380,7 @@ Pass UseAssumeToReduceBranches() { if (assume_checker.has_assume) { // Leverage from assume and eliminate the branch - ParseAssumeAndOvercompute func_analyzer_mutator(&analyzer); + ParseAssumeAndOvercompute func_analyzer_mutator(analyzer.get()); n->body = func_analyzer_mutator(std::move(n->body)); } } diff --git a/src/target/cuda/codegen_cuda.cc b/src/target/cuda/codegen_cuda.cc index fd7a250c120f..756bd2a360c7 100644 --- a/src/target/cuda/codegen_cuda.cc +++ b/src/target/cuda/codegen_cuda.cc @@ -209,10 +209,10 @@ void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) { is_persistent = true; } arith::Analyzer analyzer; - PrimExpr threadIdx_ext = analyzer.Simplify(extractor.threadIdx_x_ext * extractor.threadIdx_y_ext * - extractor.threadIdx_z_ext); + PrimExpr threadIdx_ext = analyzer->Simplify( + extractor.threadIdx_x_ext * extractor.threadIdx_y_ext * extractor.threadIdx_z_ext); PrimExpr cluster_cta_yz_ext = - analyzer.Simplify(extractor.clusterCtaIdx_y_ext * extractor.clusterCtaIdx_z_ext); + analyzer->Simplify(extractor.clusterCtaIdx_y_ext * extractor.clusterCtaIdx_z_ext); if (const IntImmNode* const cluster_cta_yz_ext_int = cluster_cta_yz_ext.as()) { cluster_cta_x_is_linear_rank_ = cluster_cta_yz_ext_int->value == 1; } else { @@ -1102,7 +1102,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { arith::Analyzer analyzer; auto inverse_index_map = - IndexMap::FromFunc(2, *index_map_func).Inverse({Range(0, m), Range(0, n)}, &analyzer); + IndexMap::FromFunc(2, *index_map_func).Inverse({Range(0, m), Range(0, n)}, analyzer); auto indices_16x16 = inverse_index_map->final_indices; // "//" and "%" in the index map are translated to FloorDiv/Mod, but the plain Div/Mod are fine. @@ -1206,7 +1206,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { arith::Analyzer analyzer; auto inverse_index_map = - IndexMap::FromFunc(2, *index_map_func).Inverse({Range(0, m), Range(0, n)}, &analyzer); + IndexMap::FromFunc(2, *index_map_func).Inverse({Range(0, m), Range(0, n)}, analyzer); auto indices_16x16 = inverse_index_map->final_indices; class LowerFloorDivMod : public ExprMutator { diff --git a/src/target/hexagon/llvm/codegen_hexagon.cc b/src/target/hexagon/llvm/codegen_hexagon.cc index e0beb0262752..a5503d209ba7 100644 --- a/src/target/hexagon/llvm/codegen_hexagon.cc +++ b/src/target/hexagon/llvm/codegen_hexagon.cc @@ -326,7 +326,7 @@ llvm::Value* CodeGenHexagon::VectorLookupLoad(Buffer buffer, DataType buffer_typ if (buffer_type.bits() != 8) return nullptr; - int table_elem_count = arith::Analyzer().Simplify(buffer->shape[0]).as()->value; + int table_elem_count = arith::Analyzer()->Simplify(buffer->shape[0]).as()->value; if (table_elem_count <= 0 || table_elem_count > 256) return nullptr; auto int32 = DataType::Int(32); diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 10a129eca74f..ea30f272712f 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -519,7 +519,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { llvm::DISubprogram* di_subprogram_{nullptr}; std::unordered_map var_map_; std::vector> loop_frame_jump_tgts_; - std::unique_ptr analyzer_{std::make_unique()}; + arith::Analyzer analyzer_{arith::Analyzer()}; CodeGenCPU* parent_; }; @@ -663,7 +663,7 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task, std::strin builder_->CreateInBoundsGEP(t_tvm_parallel_group_env_, penv, {ConstInt32(0), ConstInt32(1)}), "num_task"); par_env.penv = penv; - auto new_analyzer = std::make_unique(); + auto new_analyzer = arith::Analyzer(); std::swap(function_, f); std::swap(parallel_env_, par_env); std::swap(analyzer_, new_analyzer); @@ -716,7 +716,7 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod std::unordered_map new_vmap; UnpackClosureData(cdata, vfields, &new_vmap); TVM_FFI_ICHECK(parallel_env_.penv == nullptr); - auto new_analyzer = std::make_unique(); + auto new_analyzer = arith::Analyzer(); std::swap(function_, f); std::swap(analyzer_, new_analyzer); std::swap(var_map_, new_vmap); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 44308be5ba2f..9e15505fb44c 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -220,7 +220,7 @@ void CodeGenLLVM::InitFuncState() { alias_var_set_.clear(); alloc_storage_info_.clear(); volatile_buf_.clear(); - analyzer_.reset(new arith::Analyzer()); + analyzer_ = arith::Analyzer(); } std::tuple CodeGenLLVM::GetLinkage( diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 61d7da8ce402..d022479e4278 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -551,7 +551,7 @@ class CodeGenLLVM : public ExprFunctor, // Whether current function is restricted bool is_restricted_{true}; // The analyzer information - std::unique_ptr analyzer_; + arith::Analyzer analyzer_; // set of var that are not restricted(can alias) std::unordered_set alias_var_set_; // set of volatile buffer. diff --git a/src/target/opencl/intrin_rule_opencl.cc b/src/target/opencl/intrin_rule_opencl.cc index ba1873bde694..6f76af4b0e35 100644 --- a/src/target/opencl/intrin_rule_opencl.cc +++ b/src/target/opencl/intrin_rule_opencl.cc @@ -116,7 +116,7 @@ static PrimExpr DispatchIntelShuffle(const PrimExpr& e) { TVM_FFI_ICHECK(call != nullptr); TVM_FFI_ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size arith::Analyzer analyzer; - TVM_FFI_ICHECK(analyzer.CanProve(call->args[3] == call->args[4])) + TVM_FFI_ICHECK(analyzer->CanProve(call->args[3] == call->args[4])) << "Intel warp shuffle dose not support width != warp_size"; ffi::Array opencl_args{ {StringImm("intel_sub_group_shuffle"), call->args[1], call->args[2]}}; diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index a52c9fc4c2ef..da09c4b21a94 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -868,7 +868,7 @@ void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLI if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) { const RampNode* ramp = index.as(); TVM_FFI_ICHECK(ramp); - arith::ModularSet me = arith::Analyzer().modular_set(ramp->base); + arith::ModularSet me = arith::Analyzer()->modular_set(ramp->base); // The condition: {k * coeff + base} divisible by the alignment for any k if (me->coeff % op->dtype.lanes() == 0 && me->base % op->dtype.lanes() == 0) { can_vector_load = true; @@ -1243,7 +1243,7 @@ void CodeGenC::VisitStmt_(const AssertStmtNode* op) { void CodeGenC::VisitStmt_(const ForNode* op) { std::string begin_str = PrintExpr(op->min); - PrimExpr end = is_zero(op->min) ? op->extent : arith::Analyzer().Simplify(op->min + op->extent); + PrimExpr end = is_zero(op->min) ? op->extent : arith::Analyzer()->Simplify(op->min + op->extent); std::string end_str = PrintExpr(end); std::string step_str = op->step.has_value() ? PrintExpr(*op->step) : ""; PrintIndent(); diff --git a/src/target/vulkan/codegen_spirv.cc b/src/target/vulkan/codegen_spirv.cc index 7e9fa2b8a3df..0afd35026916 100644 --- a/src/target/vulkan/codegen_spirv.cc +++ b/src/target/vulkan/codegen_spirv.cc @@ -129,7 +129,7 @@ void CodeGenSPIRV::InitFuncState() { std::fill(workgroup_size_, workgroup_size_ + 3, 1); var_map_.clear(); storage_info_.clear(); - analyzer_.reset(new arith::Analyzer()); + analyzer_ = arith::Analyzer(); builder_.reset(new spirv::IRBuilder(spirv_support_)); builder_->InitHeader(); shared_memory_bytes_used_ = 0; diff --git a/src/target/vulkan/codegen_spirv.h b/src/target/vulkan/codegen_spirv.h index cea634d8ab42..e0e41b9b1526 100644 --- a/src/target/vulkan/codegen_spirv.h +++ b/src/target/vulkan/codegen_spirv.h @@ -222,7 +222,7 @@ class CodeGenSPIRV : public ExprFunctor, std::unordered_map var_map_; // The analyzer. - std::unique_ptr analyzer_; + arith::Analyzer analyzer_; // deep comparison of PrimExpr ExprDeepEqual deep_equal_; diff --git a/src/target/webgpu/codegen_webgpu.cc b/src/target/webgpu/codegen_webgpu.cc index fcec71d9de1b..48e4cc87b60e 100644 --- a/src/target/webgpu/codegen_webgpu.cc +++ b/src/target/webgpu/codegen_webgpu.cc @@ -688,7 +688,7 @@ void CodeGenWebGPU::VisitStmt_(const AllocBufferNode* op) { void CodeGenWebGPU::VisitStmt_(const ForNode* op) { std::string begin_str = PrintExpr(op->min); - PrimExpr end = is_zero(op->min) ? op->extent : arith::Analyzer().Simplify(op->min + op->extent); + PrimExpr end = is_zero(op->min) ? op->extent : arith::Analyzer()->Simplify(op->min + op->extent); std::string end_str = PrintExpr(end); std::string step_str = op->step.has_value() ? PrintExpr(*op->step) : ""; std::string vid = AllocVarID(op->loop_var.get()); diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 14a0549ecb1d..a4ce62812a08 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -192,7 +192,7 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator { using NestedIterLevels = std::vector>; NestedIterLevels GenerateNestedIterLevels(const ffi::Array& axes, - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { int global_max_depth = 0; std::unordered_map depth; std::unordered_map var2iter; @@ -364,7 +364,7 @@ Stmt GenerateInitStmt(const ffi::Array& indices, const ffi::Array& indices, const ffi::Array& buffers, const ffi::Map& var_map, PrimExpr expr_body, - CreateFuncInfo* info, arith::Analyzer* analyzer) { + CreateFuncInfo* info, arith::AnalyzerObj* analyzer) { // helper to transform the expr and remap iters to the block domain auto f_transform_and_remap = [&](const PrimExpr& e) { return Substitute(info->transformer(e), var_map); @@ -476,7 +476,7 @@ struct NestedScopeInfo { }; Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* info, - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { // Step 1. Collect all iter axes in original TE compute op ffi::Array axes = compute_op->axis; axes.insert(axes.end(), compute_op->reduce_axis.begin(), compute_op->reduce_axis.end()); @@ -707,7 +707,7 @@ void InitializeBufferBinds(const ffi::Array& ordered_ops, CreateF } void RewriteStageToBlock(const te::Operation& op, CreateFuncInfo* info, - ffi::Array* root_stmts, arith::Analyzer* analyzer) { + ffi::Array* root_stmts, arith::AnalyzerObj* analyzer) { if (const auto* placeholder = op.as()) { // Case 1. PlaceholderOp (te.placeholder) TVM_FFI_ICHECK_EQ(op->num_outputs(), 1); @@ -776,7 +776,7 @@ PrimFunc CreatePrimFunc(const ffi::Array& arg_list, // Step 3. Rewrite compute stages into blocks. for (const te::Operation& op : order) { - RewriteStageToBlock(op, &info, &root_stmts, &analyzer); + RewriteStageToBlock(op, &info, &root_stmts, analyzer.get()); } // Step 4. Create func and complete prim func. @@ -854,7 +854,7 @@ PrimFunc CreatePrimFunc(const ffi::Array& arg_list, // Step 3. Rewrite compute stages into blocks. for (const te::Operation& op : order) { - RewriteStageToBlock(op, &info, &root_stmts, &analyzer); + RewriteStageToBlock(op, &info, &root_stmts, analyzer.get()); } auto func = GenerateAndCompletePrimFunc(arg_list, root_stmts, &info); if (index_dtype_override.has_value()) { diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index bfee2b42227f..5e8d4361ec85 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -55,7 +55,7 @@ ScanOp::ScanOp(std::string name, std::string tag, TVM_FFI_ICHECK_EQ(init.size(), state_placeholder.size()); arith::Analyzer analyzer; auto prove_equal = [&](PrimExpr lhs, PrimExpr rhs) { - return is_zero(analyzer.Simplify(lhs - rhs)); + return is_zero(analyzer->Simplify(lhs - rhs)); }; for (size_t i = 0; i < init.size(); ++i) { diff --git a/src/tirx/analysis/exec_context.cc b/src/tirx/analysis/exec_context.cc index 93c2781da210..3ede47b31f95 100644 --- a/src/tirx/analysis/exec_context.cc +++ b/src/tirx/analysis/exec_context.cc @@ -55,7 +55,7 @@ bool TryAsInt64(const PrimExpr& expr, int64_t* value) { bool IsZero(const PrimExpr& expr) { arith::Analyzer analyzer; - return analyzer.CanProveEqual(expr, 0); + return analyzer->CanProveEqual(expr, 0); } ActiveSet MakeActiveSet(const std::vector>& axes) { diff --git a/src/tirx/ir/buffer.cc b/src/tirx/ir/buffer.cc index 9de83733372a..3fc57429a25a 100644 --- a/src/tirx/ir/buffer.cc +++ b/src/tirx/ir/buffer.cc @@ -44,7 +44,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { BufferNode::RegisterReflection(); } using IndexMod = tirx::FloorModNode; using IndexDiv = tirx::FloorDivNode; -ffi::Array SimplifyArray(arith::Analyzer* ana, ffi::Array array) { +ffi::Array SimplifyArray(arith::AnalyzerObj* ana, ffi::Array array) { for (size_t i = 0; i < array.size(); ++i) { array.Set(i, ana->Simplify(array[i])); } @@ -89,7 +89,7 @@ inline std::vector ExprSplitAddition(const PrimExpr& expr) { // If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c1) // Currently the we will not search the add/mult combinations exhaustively // as it will take too much computation. -inline std::pair MergeMulModInner(arith::Analyzer* analyzer, +inline std::pair MergeMulModInner(arith::AnalyzerObj* analyzer, const PrimExpr& mult_expr, const PrimExpr& mod_l_expr, const PrimExpr& mod_r_expr) { @@ -186,7 +186,7 @@ inline void MergeMulModInsertElements(const std::vector& eles, // The search will be performed repeatively until no pattern is found. // Return: a pair with (false, Expr()) if cannot be optimized. // a pair with (true, optimized_expr) if can be optimized -inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) { +inline PrimExpr MergeMulMod(arith::AnalyzerObj* analyzer, const PrimExpr& base) { using namespace tirx; // 1. Prepare the lists. // We store two lists, a list that contain all the elements that match Mul and @@ -306,7 +306,7 @@ ffi::Array BufferNode::ElemOffset(ffi::Array input_indices, } if (i > 0) { - output_index = MergeMulMod(&ana, output_index); + output_index = MergeMulMod(ana.get(), output_index); } output_indices.Set(current_output_axis, output_index); @@ -318,7 +318,7 @@ ffi::Array BufferNode::ElemOffset(ffi::Array input_indices, } } - return SimplifyArray(&ana, output_indices); + return SimplifyArray(ana.get(), output_indices); } inline ffi::Array BufferOffset(const BufferNode* n, ffi::Array index, @@ -499,9 +499,9 @@ Buffer Buffer::MakeSlice(ffi::Array begins, ffi::Array exten const BufferNode* n = operator->(); TVM_FFI_ICHECK(n != nullptr); arith::Analyzer ana; - begins = SimplifyArray(&ana, begins); + begins = SimplifyArray(ana.get(), begins); ffi::Array elem_offset = - n->ElemOffset(begins).Map([&](const PrimExpr& expr) { return ana.Simplify(expr); }); + n->ElemOffset(begins).Map([&](const PrimExpr& expr) { return ana->Simplify(expr); }); ffi::Array strides = n->strides; if (strides.size() == 0) { @@ -510,7 +510,7 @@ Buffer Buffer::MakeSlice(ffi::Array begins, ffi::Array exten // check if stride is needed. for (size_t i = 0; i < extents.size(); ++i) { if (!can_relax) { - if (!is_zero(begins[i]) || !is_zero(ana.Simplify(extents[i] - n->shape[i]))) { + if (!is_zero(begins[i]) || !is_zero(ana->Simplify(extents[i] - n->shape[i]))) { need_stride = true; } } diff --git a/src/tirx/ir/exec_scope.cc b/src/tirx/ir/exec_scope.cc index 7c3bda5995f4..c885c0251134 100644 --- a/src/tirx/ir/exec_scope.cc +++ b/src/tirx/ir/exec_scope.cc @@ -212,7 +212,7 @@ bool ScopeIdDefVerifier::Verify(const ffi::Array& defs, Mode mode) { it->second = upgraded; queue.push(upgraded); } else if (existing_known && new_known) { - TVM_FFI_ICHECK(ana.CanProveEqual(existing.fused_extent(), id.fused_extent())) + TVM_FFI_ICHECK(ana->CanProveEqual(existing.fused_extent(), id.fused_extent())) << "Inconsistent extents for scope binding " << static_cast(id->scope); } // else: existing wins (known beats unknown; both unknown is a no-op). @@ -316,11 +316,11 @@ static ffi::Optional Compliment(const ScopeIdDef& lhs, const ScopeId arith::Analyzer ana; auto try_compliment = [&](PrimExpr lhs_ext, PrimExpr rhs_ext, ScopeBinding scope) -> ffi::Optional { - if (ana.CanProve(floormod(lhs_ext, rhs_ext) == 0)) { + if (ana->CanProve(floormod(lhs_ext, rhs_ext) == 0)) { return ScopeIdDef(ffi::Array{Var("")}, ffi::Array{floordiv(lhs_ext, rhs_ext)}, scope); } - TVM_FFI_ICHECK(!ana.CanProve(floormod(lhs_ext, rhs_ext) != 0)) + TVM_FFI_ICHECK(!ana->CanProve(floormod(lhs_ext, rhs_ext) != 0)) << "ValueError: scope binding " << static_cast(scope) << " has non-divisible extents: " << lhs_ext << " is not divisible by " << rhs_ext; return std::nullopt; @@ -394,23 +394,23 @@ ffi::Array ResolveCuda(ScopeBinding binding, } case ScopeBinding::kCtaWarpgroup: { TVM_FFI_ICHECK_EQ(out_dim, 1) << "ValueError: cta->warpgroup must be 1D"; - return {ana.Simplify(FloorDiv(GetThread("warp_id_in_cta", params).first, 4))}; + return {ana->Simplify(FloorDiv(GetThread("warp_id_in_cta", params).first, 4))}; } case ScopeBinding::kCtaWarp: { TVM_FFI_ICHECK_EQ(out_dim, 1) << "ValueError: cta->warp must be 1D"; - return {ana.Simplify(GetThread("warp_id_in_cta", params).first)}; + return {ana->Simplify(GetThread("warp_id_in_cta", params).first)}; } case ScopeBinding::kWarpgroupWarp: { TVM_FFI_ICHECK_EQ(out_dim, 1) << "ValueError: warpgroup->warp must be 1D"; - return {ana.Simplify(FloorMod(GetThread("warp_id_in_cta", params).first, 4))}; + return {ana->Simplify(FloorMod(GetThread("warp_id_in_cta", params).first, 4))}; } case ScopeBinding::kWarpgroupThread: { TVM_FFI_ICHECK_EQ(out_dim, 1) << "ValueError: warpgroup->thread must be 1D"; - return {ana.Simplify(FloorMod(GetLinearThreadIndex(params), 128))}; + return {ana->Simplify(FloorMod(GetLinearThreadIndex(params), 128))}; } case ScopeBinding::kWarpThread: { TVM_FFI_ICHECK_EQ(out_dim, 1) << "ValueError: warp->thread must be 1D"; - return {ana.Simplify(FloorMod(GetLinearThreadIndex(params), 32))}; + return {ana->Simplify(FloorMod(GetLinearThreadIndex(params), 32))}; } case ScopeBinding::kClusterCtaPair: { TVM_FFI_ICHECK_EQ(out_dim, 1) << "ValueError: cluster->cta_pair must be 1D"; @@ -418,7 +418,7 @@ ffi::Array ResolveCuda(ScopeBinding binding, std::tie(cbx, ex) = GetThread("clusterCtaIdx.x", params, true); std::tie(cby, ey) = GetThread("clusterCtaIdx.y", params, true); std::tie(cbz, ez) = GetThread("clusterCtaIdx.z", params, true); - return {ana.Simplify(FloorMod(cbx + cby * ex + cbz * ex * ey, 2))}; + return {ana->Simplify(FloorMod(cbx + cby * ex + cbz * ex * ey, 2))}; } } LOG(FATAL) << "Internal Error: unknown ScopeBinding " << static_cast(binding); diff --git a/src/tirx/ir/index_map.cc b/src/tirx/ir/index_map.cc index 43b379351a70..1e27503c082e 100644 --- a/src/tirx/ir/index_map.cc +++ b/src/tirx/ir/index_map.cc @@ -61,8 +61,9 @@ IndexMap IndexMap::FromFunc(int ndim, std::pair IndexMapInverseImpl(const IndexMap& self, const ffi::Array& initial_ranges, arith::IterMapLevel check_level, - arith::Analyzer* analyzer) { + arith::AnalyzerObj* analyzer) { TVM_FFI_ICHECK(analyzer != nullptr); + arith::Analyzer analyzer_ref = ffi::GetRef(analyzer); if (self->inverse_index_map.defined()) { // return the pre-defined inverse index map if exists. In this // case, the user-defined inverse is assumed to be correct and @@ -96,7 +97,7 @@ std::pair IndexMapInverseImpl(const IndexMap& self, // Unpack the output indices into linear combinations of the initial // indices. auto padded_iter_map = DetectIterMap(self->final_indices, input_iters, /*predicate=*/1, - /*check_level=*/check_level, analyzer, + /*check_level=*/check_level, analyzer_ref, /*simplify_trivial_iterators=*/false); TVM_FFI_ICHECK(padded_iter_map->errors.empty()) << "Could not parse mapping as sum of iterators. " @@ -124,41 +125,57 @@ std::pair IndexMapInverseImpl(const IndexMap& self, padding_predicate = arith::NormalizeIterMapToExpr(padding_predicate); padding_predicate = Substitute(padding_predicate, inverse_exprs_map); - auto output_ranges = self->MapRanges(initial_ranges, analyzer); + auto output_ranges = self->MapRanges(initial_ranges, analyzer_ref); { TVM_FFI_ICHECK_EQ(output_ranges.size(), output_vars.size()); arith::Analyzer analyzer; for (size_t i = 0; i < output_vars.size(); ++i) { - analyzer.Bind(output_vars[i], output_ranges[i]); + analyzer->Bind(output_vars[i], output_ranges[i]); } // Additional simplification steps required to unwrap nested floordiv/floormod - padding_predicate = analyzer.Simplify(padding_predicate, 10); + padding_predicate = analyzer->Simplify(padding_predicate, 10); } return {IndexMap(output_vars, inverse_exprs), padding_predicate}; } -std::pair IndexMap::NonSurjectiveInverse(ffi::Array initial_ranges, - arith::Analyzer* analyzer) const { - TVM_FFI_ICHECK(analyzer != nullptr); - return IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::NoCheck, analyzer); +std::pair IndexMap::NonSurjectiveInverse( + ffi::Array initial_ranges) const { + arith::Analyzer analyzer; + return NonSurjectiveInverse(initial_ranges, analyzer); } -IndexMap IndexMap::Inverse(ffi::Array initial_ranges, arith::Analyzer* analyzer) const { - TVM_FFI_ICHECK(analyzer != nullptr); +std::pair IndexMap::NonSurjectiveInverse( + ffi::Array initial_ranges, const arith::Analyzer& analyzer) const { + return IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::NoCheck, analyzer.get()); +} + +IndexMap IndexMap::Inverse(ffi::Array initial_ranges) const { + arith::Analyzer analyzer; + return Inverse(initial_ranges, analyzer); +} + +IndexMap IndexMap::Inverse(ffi::Array initial_ranges, + const arith::Analyzer& analyzer) const { + arith::AnalyzerObj* analyzer_ptr = analyzer.get(); auto [inverse, padding_predicate] = - IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::Bijective, analyzer); - TVM_FFI_ICHECK(analyzer->CanProve(!padding_predicate)) + IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::Bijective, analyzer_ptr); + TVM_FFI_ICHECK(analyzer_ptr->CanProve(!padding_predicate)) << "Bijective inverse should not contain padding, but inverse of " << *this << " over range " << initial_ranges << " resulted in a padding predicate of " << padding_predicate; return inverse; } +ffi::Array IndexMapNode::MapIndices(const ffi::Array& indices) const { + arith::Analyzer analyzer; + return MapIndices(indices, analyzer); +} + ffi::Array IndexMapNode::MapIndices(const ffi::Array& indices, - arith::Analyzer* analyzer) const { - TVM_FFI_ICHECK(analyzer != nullptr); + const arith::Analyzer& analyzer) const { + arith::AnalyzerObj* analyzer_ptr = analyzer.get(); TVM_FFI_ICHECK_EQ(indices.size(), initial_indices.size()); ffi::Map vmap; @@ -170,14 +187,19 @@ ffi::Array IndexMapNode::MapIndices(const ffi::Array& indice ffi::Array output = final_indices.Map([&](PrimExpr index) { PrimExpr result = SubstituteWithDataTypeLegalization( std::move(index), [&](const Var& var) { return vmap.Get(var); }); - return analyzer->Simplify(result); + return analyzer_ptr->Simplify(result); }); return output; } +ffi::Array IndexMapNode::MapRanges(const ffi::Array& ranges) const { + arith::Analyzer analyzer; + return MapRanges(ranges, analyzer); +} + ffi::Array IndexMapNode::MapRanges(const ffi::Array& ranges, - arith::Analyzer* analyzer) const { - TVM_FFI_ICHECK(analyzer != nullptr); + const arith::Analyzer& analyzer) const { + arith::AnalyzerObj* analyzer_ptr = analyzer.get(); TVM_FFI_ICHECK_EQ(ranges.size(), initial_indices.size()); ffi::Map input_iters; @@ -217,8 +239,9 @@ ffi::Array IndexMapNode::MapRanges(const ffi::Array& ranges, for (const auto& final_index : final_indices) { auto int_set = arith::EvalSet(final_index, dom_map); - output.push_back(Range::FromMinExtent(analyzer->Simplify(int_set.min()), - analyzer->Simplify(int_set.max() - int_set.min() + 1))); + output.push_back( + Range::FromMinExtent(analyzer_ptr->Simplify(int_set.min()), + analyzer_ptr->Simplify(int_set.max() - int_set.min() + 1))); } } auto output_dtype = [&]() { @@ -239,9 +262,13 @@ ffi::Array IndexMapNode::MapRanges(const ffi::Array& ranges, return output; } +ffi::Array IndexMapNode::MapShape(const ffi::Array& shape) const { + arith::Analyzer analyzer; + return MapShape(shape, analyzer); +} + ffi::Array IndexMapNode::MapShape(const ffi::Array& shape, - arith::Analyzer* analyzer) const { - TVM_FFI_ICHECK(analyzer != nullptr); + const arith::Analyzer& analyzer) const { TVM_FFI_ICHECK_EQ(shape.size(), initial_indices.size()); ffi::Array ranges; @@ -271,7 +298,7 @@ runtime::Tensor IndexMapNode::MapTensor(runtime::Tensor arr_src) const { size_1d *= shape[i]; orig_shape.push_back(PrimExpr(static_cast((shape[i])))); } - auto dst_shape = MapShape(orig_shape, &analyzer); + auto dst_shape = MapShape(orig_shape, analyzer); std::vector dst_shape_int; for (size_t i = 0; i < dst_shape.size(); ++i) { @@ -295,7 +322,7 @@ runtime::Tensor IndexMapNode::MapTensor(runtime::Tensor arr_src) const { src_indices.push_back(PrimExpr(static_cast((src_linear_index / div_factor)))); src_linear_index %= div_factor; } - auto dst_indices = MapIndices(src_indices, &analyzer); + auto dst_indices = MapIndices(src_indices, analyzer); // Convert an N-d coordinate to a linear coordinate // (z, y, x) -> z * height * width + y * width + x @@ -434,26 +461,34 @@ TVM_FFI_STATIC_INIT_BLOCK() { return IndexMap(initial_indices, final_indices, inverse_index_map); }) .def("tirx.IndexMapMapIndices", - [](IndexMap map, ffi::Array indices) { - arith::Analyzer analyzer; - return map->MapIndices(indices, &analyzer); + [](IndexMap map, ffi::Array indices, + ffi::Optional opt_analyzer) { + arith::Analyzer analyzer = + opt_analyzer.has_value() ? opt_analyzer.value() : arith::Analyzer(); + return map->MapIndices(indices, analyzer); }) .def("tirx.IndexMapMapShape", - [](IndexMap map, ffi::Array shape) { - arith::Analyzer analyzer; - return map->MapShape(shape, &analyzer); + [](IndexMap map, ffi::Array shape, + ffi::Optional opt_analyzer) { + arith::Analyzer analyzer = + opt_analyzer.has_value() ? opt_analyzer.value() : arith::Analyzer(); + return map->MapShape(shape, analyzer); }) .def("tirx.IndexMapInverse", - [](IndexMap map, ffi::Array initial_ranges) { - arith::Analyzer analyzer; - return map.Inverse(initial_ranges, &analyzer); + [](IndexMap map, ffi::Array initial_ranges, + ffi::Optional opt_analyzer) { + arith::Analyzer analyzer = + opt_analyzer.has_value() ? opt_analyzer.value() : arith::Analyzer(); + return map.Inverse(initial_ranges, analyzer); }) .def("tirx.IndexMapMapTensor", [](IndexMap map, runtime::Tensor arr) { return map->MapTensor(arr); }) .def("tirx.IndexMapNonSurjectiveInverse", - [](IndexMap forward, ffi::Array initial_ranges) { - arith::Analyzer analyzer; - auto result = forward.NonSurjectiveInverse(initial_ranges, &analyzer); + [](IndexMap forward, ffi::Array initial_ranges, + ffi::Optional opt_analyzer) { + arith::Analyzer analyzer = + opt_analyzer.has_value() ? opt_analyzer.value() : arith::Analyzer(); + auto result = forward.NonSurjectiveInverse(initial_ranges, analyzer); return ffi::Array{result.first, result.second}; }); } diff --git a/src/tirx/ir/layout/axis_registry.cc b/src/tirx/ir/layout/axis_registry.cc index 942e69bfd0e8..2afd290037c8 100644 --- a/src/tirx/ir/layout/axis_registry.cc +++ b/src/tirx/ir/layout/axis_registry.cc @@ -163,15 +163,15 @@ void AxisRegEntry::UpdateAttr(const ffi::String& key, ffi::Any value, int plevel ffi::Array SplitterGen(const Iter& iter, const Axis& axis_outer, const Axis& axis_inner, const PrimExpr& e_inner) { arith::Analyzer analyzer; - if (analyzer.CanProve(iter->extent * iter->stride < e_inner)) { + if (analyzer->CanProve(iter->extent * iter->stride < e_inner)) { return {Iter(iter->extent, iter->stride, axis_inner)}; - } else if (analyzer.CanProveEqual(floormod(e_inner, iter->stride), 0) && - analyzer.CanProveEqual(floormod(iter->extent * iter->stride, e_inner), 0)) { - const auto& d = analyzer.Simplify(floordiv(e_inner, iter->stride)); - const auto& c = analyzer.Simplify(floordiv(iter->extent, d)); + } else if (analyzer->CanProveEqual(floormod(e_inner, iter->stride), 0) && + analyzer->CanProveEqual(floormod(iter->extent * iter->stride, e_inner), 0)) { + const auto& d = analyzer->Simplify(floordiv(e_inner, iter->stride)); + const auto& c = analyzer->Simplify(floordiv(iter->extent, d)); return {Iter(c, IntImm(e_inner.dtype(), 1), axis_outer), Iter(d, iter->stride, axis_inner)}; - } else if (analyzer.CanProveEqual(floormod(iter->stride, e_inner), 0)) { - const auto& d = analyzer.Simplify(floordiv(iter->stride, e_inner)); + } else if (analyzer->CanProveEqual(floormod(iter->stride, e_inner), 0)) { + const auto& d = analyzer->Simplify(floordiv(iter->stride, e_inner)); return {Iter(iter->extent, d, axis_outer)}; } return {}; diff --git a/src/tirx/ir/layout/swizzle_layout.cc b/src/tirx/ir/layout/swizzle_layout.cc index 59f31199283b..aa80223085b0 100644 --- a/src/tirx/ir/layout/swizzle_layout.cc +++ b/src/tirx/ir/layout/swizzle_layout.cc @@ -80,7 +80,7 @@ ffi::Map SwizzleLayoutNode::Apply(PrimExpr coord) const { // It takes more arithmetic operations to compute the result, but it is more friendly to the // vectorization. We use "m" as the default axis name here. return { - {"m", analyzer.Simplify((f(floordiv(input, base)) << per_element) + floormod(input, base))}}; + {"m", analyzer->Simplify((f(floordiv(input, base)) << per_element) + floormod(input, base))}}; } Layout SwizzleLayoutNode::Canonicalize() const { return ffi::GetRef(this); } diff --git a/src/tirx/ir/layout/tile_canonicalize.cc b/src/tirx/ir/layout/tile_canonicalize.cc index 834a42afbf8e..603e1f18e931 100644 --- a/src/tirx/ir/layout/tile_canonicalize.cc +++ b/src/tirx/ir/layout/tile_canonicalize.cc @@ -62,7 +62,7 @@ TileLayout FuseContiguousShardIters(TileLayout layout) { PrimExpr extent = shard[cur]->extent; size_t next = cur + 1; while (next < shard.size() && shard[next]->axis.same_as(shard[cur]->axis) && - ana.CanProveEqual(shard[next]->extent * shard[next]->stride, shard[next - 1]->stride)) { + ana->CanProveEqual(shard[next]->extent * shard[next]->stride, shard[next - 1]->stride)) { extent *= shard[next]->extent; ++next; } diff --git a/src/tirx/ir/layout/tile_core.cc b/src/tirx/ir/layout/tile_core.cc index 19b6b5f4b986..979f95c21005 100644 --- a/src/tirx/ir/layout/tile_core.cc +++ b/src/tirx/ir/layout/tile_core.cc @@ -75,7 +75,7 @@ bool VerifyCompactness(const std::vector& iters) { PrimExpr stride_to_find = 1; for (size_t i = 0; i < iters.size(); ++i) { auto iter = std::find_if(iters.begin(), iters.end(), [&](const Iter& iter) { - return analyzer.CanProveEqual(iter->stride, stride_to_find); + return analyzer->CanProveEqual(iter->stride, stride_to_find); }); if (iter == iters.end()) return false; stride_to_find *= (*iter)->extent; @@ -140,7 +140,7 @@ PrimExpr TileLayoutNode::GetSpan(ffi::Optional axis_name) const { for (const auto& [axis, off] : offset) { if (filter(axis)) result += off; } - return analyzer.Simplify(result); + return analyzer->Simplify(result); } ffi::Map TileLayoutNode::Apply(PrimExpr coord) const { @@ -192,18 +192,18 @@ ffi::Map TileLayoutNode::Apply(Array coord) con for (size_t i = 0; i < shard.size(); ++i) { auto it = result.find(shard[i]->axis->name); if (it == result.end()) { - result[shard[i]->axis->name] = analyzer.Simplify(coord[i] * shard[i]->stride); + result[shard[i]->axis->name] = analyzer->Simplify(coord[i] * shard[i]->stride); } else { - result[shard[i]->axis->name] = analyzer.Simplify(it->second + coord[i] * shard[i]->stride); + result[shard[i]->axis->name] = analyzer->Simplify(it->second + coord[i] * shard[i]->stride); } } // Add offset to the result for (const auto& [axis, off] : offset) { auto it = result.find(axis->name); if (it == result.end()) { - result[axis->name] = analyzer.Simplify(off); + result[axis->name] = analyzer->Simplify(off); } else { - result[axis->name] = analyzer.Simplify(it->second + off); + result[axis->name] = analyzer->Simplify(it->second + off); } } return result; diff --git a/src/tirx/ir/layout/tile_direct_sum_ops.cc b/src/tirx/ir/layout/tile_direct_sum_ops.cc index 481b3bd80ee2..33622453dc20 100644 --- a/src/tirx/ir/layout/tile_direct_sum_ops.cc +++ b/src/tirx/ir/layout/tile_direct_sum_ops.cc @@ -61,7 +61,7 @@ Layout TileLayoutNode::DirectSum(const TileLayout& left_in, const Arrayoffset) { auto it = sum_off.find(axis); if (it != sum_off.end()) { - sum_off.Set(axis, analyzer.Simplify((*it).second + off)); + sum_off.Set(axis, analyzer->Simplify((*it).second + off)); } else { sum_off.Set(axis, off); } @@ -70,7 +70,7 @@ Layout TileLayoutNode::DirectSum(const TileLayout& left_in, const ArrayCanonicalize(); } -static bool IterEqualRelaxUnit(const Iter& a, const Iter& b, arith::Analyzer* analyzer) { +static bool IterEqualRelaxUnit(const Iter& a, const Iter& b, arith::AnalyzerObj* analyzer) { if (!(*analyzer).CanProveEqual(a->extent, b->extent)) return false; if (!is_one(a->extent)) { if (!(*analyzer).CanProveEqual(a->stride, b->stride)) return false; @@ -88,9 +88,9 @@ static ffi::Map SubtractOffsets(const ffi::Map& for (const auto& [axis, off] : rhs) { auto it = res.find(axis); if (it != res.end()) { - res.Set(axis, analyzer.Simplify((*it).second - off)); + res.Set(axis, analyzer->Simplify((*it).second - off)); } else { - res.Set(axis, analyzer.Simplify(-off)); + res.Set(axis, analyzer->Simplify(-off)); } } return res; @@ -128,7 +128,7 @@ ffi::Optional TileLayoutNode::IsDirectSumRight( for (int j = 0; j < right_cnt; ++j) { Iter s_iter = grouped_sum->shard[sum_seps[2 * i + 2] - right_cnt + j]; Iter r_iter = grouped_right->shard[right_seps[i] + j]; - if (!IterEqualRelaxUnit(s_iter, r_iter, &analyzer)) return std::nullopt; + if (!IterEqualRelaxUnit(s_iter, r_iter, analyzer.get())) return std::nullopt; } // If sum_right_cnt > right_cnt, residual dims cannot be attributed; reject for now. if (sum_right_cnt != right_cnt) return std::nullopt; @@ -175,7 +175,7 @@ ffi::Optional TileLayoutNode::IsDirectSumLeft( for (int j = 0; j < left_cnt; ++j) { Iter s_iter = grouped_sum->shard[sum_seps[2 * i] + j]; Iter l_iter = grouped_left->shard[left_seps[i] + j]; - if (!IterEqualRelaxUnit(s_iter, l_iter, &analyzer)) return std::nullopt; + if (!IterEqualRelaxUnit(s_iter, l_iter, analyzer.get())) return std::nullopt; } // If sum_left_cnt > left_cnt, residual dims cannot be attributed; reject for now. if (sum_left_cnt != left_cnt) return std::nullopt; diff --git a/src/tirx/ir/layout/tile_slice.cc b/src/tirx/ir/layout/tile_slice.cc index 5d8762e0d4cf..3f4db4837964 100644 --- a/src/tirx/ir/layout/tile_slice.cc +++ b/src/tirx/ir/layout/tile_slice.cc @@ -40,7 +40,7 @@ ffi::Optional SlicePerGroup(TileLayout layout, PrimExpr begin, PrimE PrimExpr acc = PrimExpr(1); for (int k = m - 1; k >= 0; --k) { B[k] = acc; - acc = analyzer.Simplify(acc * shard[k]->extent); + acc = analyzer->Simplify(acc * shard[k]->extent); } std::vector d0(m); @@ -50,9 +50,9 @@ ffi::Optional SlicePerGroup(TileLayout layout, PrimExpr begin, PrimE auto add_axis_offset = [&](const Axis& axis, PrimExpr value) { auto it = new_offset.find(axis); if (it != new_offset.end()) { - new_offset.Set(axis, analyzer.Simplify((*it).second + value)); + new_offset.Set(axis, analyzer->Simplify((*it).second + value)); } else { - new_offset.Set(axis, analyzer.Simplify(value)); + new_offset.Set(axis, analyzer->Simplify(value)); } }; @@ -71,12 +71,12 @@ ffi::Optional SlicePerGroup(TileLayout layout, PrimExpr begin, PrimE // loop). Skip the mod when ``m == 1`` and rely on the contract. PrimExpr dk0; if (m == 1) { - dk0 = analyzer.Simplify(floordiv(begin, B[k])); + dk0 = analyzer->Simplify(floordiv(begin, B[k])); } else { - dk0 = analyzer.Simplify(floormod(floordiv(begin, B[k]), Ek)); + dk0 = analyzer->Simplify(floormod(floordiv(begin, B[k]), Ek)); } d0[k] = dk0; - add_axis_offset(ak, analyzer.Simplify(dk0 * Sk)); + add_axis_offset(ak, analyzer->Simplify(dk0 * Sk)); } // Special case: @@ -95,14 +95,14 @@ ffi::Optional SlicePerGroup(TileLayout layout, PrimExpr begin, PrimE for (; pivot >= 0; --pivot) { const PrimExpr& Ek = shard[pivot]->extent; bool peelable = - analyzer.CanProveEqual(d0[pivot], 0) && analyzer.CanProveEqual(floormod(rem, Ek), 0); + analyzer->CanProveEqual(d0[pivot], 0) && analyzer->CanProveEqual(floormod(rem, Ek), 0); if (!peelable) break; peeled_rev.push_back(shard[pivot]); - rem = analyzer.Simplify(floordiv(rem, Ek)); + rem = analyzer->Simplify(floordiv(rem, Ek)); } if (pivot < 0) { - if (!analyzer.CanProveEqual(rem, 1)) return std::nullopt; + if (!analyzer->CanProveEqual(rem, 1)) return std::nullopt; std::vector peeled_slow_to_fast(peeled_rev.rbegin(), peeled_rev.rend()); return TileLayout(peeled_slow_to_fast, layout->replica, new_offset); } @@ -111,7 +111,7 @@ ffi::Optional SlicePerGroup(TileLayout layout, PrimExpr begin, PrimE const PrimExpr& Sk = shard[pivot]->stride; const Axis& ak = shard[pivot]->axis; - if (analyzer.CanProve(d0[pivot] + rem <= Ek)) { + if (analyzer->CanProve(d0[pivot] + rem <= Ek)) { std::vector new_shard; new_shard.push_back(Iter(rem, Sk, ak)); new_shard.insert(new_shard.end(), peeled_rev.rbegin(), peeled_rev.rend()); @@ -119,17 +119,17 @@ ffi::Optional SlicePerGroup(TileLayout layout, PrimExpr begin, PrimE } PrimExpr two = make_const(rem.dtype(), 2); - PrimExpr c = analyzer.Simplify(floordiv(rem, two)); - bool even = analyzer.CanProveEqual(floormod(rem, two), 0); - bool mid = analyzer.CanProveEqual(analyzer.Simplify(d0[pivot] + c), Ek); + PrimExpr c = analyzer->Simplify(floordiv(rem, two)); + bool even = analyzer->CanProveEqual(floormod(rem, two), 0); + bool mid = analyzer->CanProveEqual(analyzer->Simplify(d0[pivot] + c), Ek); bool cap = true; if (pivot > 0) { - cap = analyzer.CanProve(analyzer.Simplify(d0[pivot - 1] + 1 <= shard[pivot - 1]->extent)); + cap = analyzer->CanProve(analyzer->Simplify(d0[pivot - 1] + 1 <= shard[pivot - 1]->extent)); } if (even && mid && cap) { if (pivot == 0 || shard[pivot - 1]->axis.same_as(ak)) { PrimExpr delta = - analyzer.Simplify((pivot > 0 ? shard[pivot - 1]->stride : PrimExpr(0)) - (Ek - c) * Sk); + analyzer->Simplify((pivot > 0 ? shard[pivot - 1]->stride : PrimExpr(0)) - (Ek - c) * Sk); std::vector new_shard; new_shard.push_back(Iter(make_const(c.dtype(), 2), delta, ak)); new_shard.push_back(Iter(c, Sk, ak)); @@ -151,16 +151,16 @@ ffi::Optional TileLayoutNode::Slice(const Array& shape, std::vector shard(grouped_layout->shard.begin() + seps[i], grouped_layout->shard.begin() + seps[i + 1]); TileLayout group = TileLayout(shard, {}, {}); - auto sliced_opt = SlicePerGroup(group, region[i]->min, analyzer.Simplify(region[i]->extent)); + auto sliced_opt = SlicePerGroup(group, region[i]->min, analyzer->Simplify(region[i]->extent)); if (!sliced_opt.has_value()) return std::nullopt; auto sliced = sliced_opt.value(); new_shard.insert(new_shard.end(), sliced->shard.begin(), sliced->shard.end()); for (const auto& [axis, off] : sliced->offset) { auto it = new_offset.find(axis); if (it != new_offset.end()) { - new_offset.Set(axis, analyzer.Simplify((*it).second + off)); + new_offset.Set(axis, analyzer->Simplify((*it).second + off)); } else { - new_offset.Set(axis, analyzer.Simplify(off)); + new_offset.Set(axis, analyzer->Simplify(off)); } } } diff --git a/src/tirx/ir/layout/tile_tile_ops.cc b/src/tirx/ir/layout/tile_tile_ops.cc index 7ab4cb0131fe..e6cabdd98aba 100644 --- a/src/tirx/ir/layout/tile_tile_ops.cc +++ b/src/tirx/ir/layout/tile_tile_ops.cc @@ -39,27 +39,27 @@ std::pair> Group(TileLayout layout, auto stride_i = layout->shard[i]->stride; prod *= extent_i; while (shape_idx < shape.size() && - analyzer.CanProveEqual(floormod(prod, shape[shape_idx]), 0)) { + analyzer->CanProveEqual(floormod(prod, shape[shape_idx]), 0)) { // Simplify ``c``, ``floordiv(extent_i, c)`` and ``stride_i * c`` — // without this, splitting an iter whose extent contains a symbolic // dim that algebraically cancels (e.g. ``floordiv(batch_size, // batch_size) == 1``) leaves dead ``a // a`` factors in the new // iter's stride that ``int(stride)`` can't unwrap downstream. - PrimExpr c = analyzer.Simplify(floordiv(prod, shape[shape_idx])); - TVM_FFI_ICHECK(analyzer.CanProveEqual(floormod(extent_i, c), 0)) + PrimExpr c = analyzer->Simplify(floordiv(prod, shape[shape_idx])); + TVM_FFI_ICHECK(analyzer->CanProveEqual(floormod(extent_i, c), 0)) << "layout " << layout << " can not be grouped by shape " << shape; - new_shard.push_back(Iter(analyzer.Simplify(floordiv(extent_i, c)), - analyzer.Simplify(stride_i * c), layout->shard[i]->axis)); + new_shard.push_back(Iter(analyzer->Simplify(floordiv(extent_i, c)), + analyzer->Simplify(stride_i * c), layout->shard[i]->axis)); extent_i = c; prod = c; shape_idx++; seps.push_back(new_shard.size()); } - extent_i = analyzer.Simplify(extent_i); + extent_i = analyzer->Simplify(extent_i); if (!is_one(extent_i)) { TVM_FFI_ICHECK(shape_idx < shape.size()) << "layout " << layout << " can not be grouped by shape " << shape; - new_shard.push_back(Iter(extent_i, analyzer.Simplify(stride_i), layout->shard[i]->axis)); + new_shard.push_back(Iter(extent_i, analyzer->Simplify(stride_i), layout->shard[i]->axis)); } } @@ -88,20 +88,20 @@ std::optional>> TryGroup( auto stride_i = layout->shard[i]->stride; prod *= extent_i; while (shape_idx < shape.size() && - analyzer.CanProveEqual(floormod(prod, shape[shape_idx]), 0)) { - PrimExpr c = analyzer.Simplify(floordiv(prod, shape[shape_idx])); - if (!analyzer.CanProveEqual(floormod(extent_i, c), 0)) return std::nullopt; - new_shard.push_back(Iter(analyzer.Simplify(floordiv(extent_i, c)), - analyzer.Simplify(stride_i * c), layout->shard[i]->axis)); + analyzer->CanProveEqual(floormod(prod, shape[shape_idx]), 0)) { + PrimExpr c = analyzer->Simplify(floordiv(prod, shape[shape_idx])); + if (!analyzer->CanProveEqual(floormod(extent_i, c), 0)) return std::nullopt; + new_shard.push_back(Iter(analyzer->Simplify(floordiv(extent_i, c)), + analyzer->Simplify(stride_i * c), layout->shard[i]->axis)); extent_i = c; prod = c; shape_idx++; seps.push_back(new_shard.size()); } - extent_i = analyzer.Simplify(extent_i); + extent_i = analyzer->Simplify(extent_i); if (!is_one(extent_i)) { if (shape_idx >= shape.size()) return std::nullopt; - new_shard.push_back(Iter(extent_i, analyzer.Simplify(stride_i), layout->shard[i]->axis)); + new_shard.push_back(Iter(extent_i, analyzer->Simplify(stride_i), layout->shard[i]->axis)); } } @@ -194,7 +194,7 @@ ffi::Array TileShape(ffi::Array shape, ffi::Array ffi::Array new_shape; for (int i = 0; i < static_cast(shape.size()); ++i) { - TVM_FFI_ICHECK(analyzer.CanProveEqual(floormod(shape[i], factor[i]), 0)) + TVM_FFI_ICHECK(analyzer->CanProveEqual(floormod(shape[i], factor[i]), 0)) << "Shape[i] must be divisible by factor[i]"; if (is_inner) { @@ -294,7 +294,7 @@ ffi::Optional TileLayoutNode::IsTileInner( auto rescale_by_inner_span = [&](const Iter& iter) -> ffi::Optional { auto it = inner_span_map.find(iter->axis->name); if (it != inner_span_map.end() && !is_one(iter->extent)) { - if (!analyzer.CanProveEqual(floormod(iter->stride, (*it).second), 0)) { + if (!analyzer->CanProveEqual(floormod(iter->stride, (*it).second), 0)) { return std::nullopt; } return Iter(iter->extent, floordiv(iter->stride, (*it).second), iter->axis); @@ -327,9 +327,9 @@ ffi::Optional TileLayoutNode::IsTileInner( for (int j = 0; j < inner_count; ++j) { Iter inner_iter = grouped_layout->shard[inner_seps[i] + j]; Iter tiled_iter = grouped_tiled->shard[tiled_seps_even[i + 1] - inner_count + j]; - if (!analyzer.CanProveEqual(inner_iter->extent, tiled_iter->extent) || + if (!analyzer->CanProveEqual(inner_iter->extent, tiled_iter->extent) || (!is_one(inner_iter->extent) && - !(analyzer.CanProveEqual(inner_iter->stride, tiled_iter->stride) && + !(analyzer->CanProveEqual(inner_iter->stride, tiled_iter->stride) && inner_iter->axis.same_as(tiled_iter->axis)))) { return std::nullopt; } @@ -357,7 +357,7 @@ ffi::Optional TileLayoutNode::IsTileInner( for (const auto& [axis, off] : tiled->offset) { auto it = layout->offset.find(axis); if (it != layout->offset.end()) { - outer_exclude.Set(axis, analyzer.Simplify(off - (*it).second)); + outer_exclude.Set(axis, analyzer->Simplify(off - (*it).second)); } else { outer_exclude.Set(axis, off); } @@ -416,7 +416,7 @@ ffi::Optional TileLayoutNode::IsTileOuter(const Layout& tile_layout, for (int j = 0; j < outer_count; ++j) { Iter outer_iter = grouped_layout->shard[outer_seps[i] + j]; Iter tiled_iter = grouped_tiled->shard[tiled_seps_even[i] + j]; - if (!analyzer.CanProveEqual(outer_iter->extent, tiled_iter->extent) || + if (!analyzer->CanProveEqual(outer_iter->extent, tiled_iter->extent) || (!is_one(outer_iter->extent) && !outer_iter->axis.same_as(tiled_iter->axis))) { return std::nullopt; } @@ -440,7 +440,7 @@ ffi::Optional TileLayoutNode::IsTileOuter(const Layout& tile_layout, for (const auto& [axis, off] : tiled->offset) { auto it = layout->offset.find(axis); if (it != layout->offset.end()) { - inner_exclude.Set(axis, analyzer.Simplify(off - (*it).second)); + inner_exclude.Set(axis, analyzer->Simplify(off - (*it).second)); } else { inner_exclude.Set(axis, off); } diff --git a/src/tirx/ir/stmt.cc b/src/tirx/ir/stmt.cc index b0c0f6d037d3..146233589a30 100644 --- a/src/tirx/ir/stmt.cc +++ b/src/tirx/ir/stmt.cc @@ -556,14 +556,14 @@ MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) { << source->region.size() << " vs. " << buffer->shape.size(); size_t offset = source->region.size() - buffer->shape.size(); for (size_t i = 0; i < offset; ++i) { - TVM_FFI_ICHECK(analyzer.CanProve(source->region[i]->extent == 1)) + TVM_FFI_ICHECK(analyzer->CanProve(source->region[i]->extent == 1)) << "The higher dimension should be 1, but got " << source->region[i]->extent << "."; } for (size_t i = 0; i < buffer->shape.size(); ++i) { const Range& source_range = source->region[i + offset]; const PrimExpr& buffer_shape = buffer->shape[i]; if (!buffer_shape->IsInstance()) { - TVM_FFI_ICHECK(analyzer.CanProve(source_range->extent == buffer_shape)) + TVM_FFI_ICHECK(analyzer->CanProve(source_range->extent == buffer_shape)) << "The dimension mismatched between source region and target buffer shape, got " << source_range->extent << " vs. " << buffer_shape << "."; } diff --git a/src/tirx/script/builder/ir.cc b/src/tirx/script/builder/ir.cc index 500ac254e1cc..e3daba408cd4 100644 --- a/src/tirx/script/builder/ir.cc +++ b/src/tirx/script/builder/ir.cc @@ -507,7 +507,7 @@ ffi::Array Remap(ffi::String kinds, ffi::Array bindings, DataType ffi::Optional> annotations, \ ffi::Optional step) { \ PrimExpr min = start; \ - PrimExpr extent = arith::Analyzer().Simplify(stop - start); \ + PrimExpr extent = arith::Analyzer()->Simplify(stop - start); \ ffi::ObjectPtr n = ffi::make_object(); \ int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \ n->vars = {Var("v", DataType(min.dtype().code(), bits, 1))}; \ @@ -536,7 +536,7 @@ ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, ffi::String thread, ffi::Optional> annotations) { using namespace tvm::tirx; PrimExpr min = start; - PrimExpr extent = arith::Analyzer().Simplify(stop - start); + PrimExpr extent = arith::Analyzer()->Simplify(stop - start); ffi::ObjectPtr n = ffi::make_object(); int bits = std::max(min.dtype().bits(), extent.dtype().bits()); DataType dtype = DataType(min.dtype().code(), bits, 1); @@ -638,7 +638,7 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) { if (!iter_var->dom.defined()) { const_cast(iter_var.get())->dom = Range(tvm::tirx::make_zero(extent.dtype()), extent); - } else if (!arith::Analyzer().CanProveEqual(iter_var->dom->extent, extent)) { + } else if (!arith::Analyzer()->CanProveEqual(iter_var->dom->extent, extent)) { TVM_FFI_THROW(InternalError) << "ValueError: Inconsistent extents of environment thread. " << iter_var->dom->extent << " vs " << extent; } diff --git a/src/tirx/transform/flatten_buffer.cc b/src/tirx/transform/flatten_buffer.cc index 485f3347f280..7298c2df2092 100644 --- a/src/tirx/transform/flatten_buffer.cc +++ b/src/tirx/transform/flatten_buffer.cc @@ -45,7 +45,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { public: static PrimFunc Flatten(PrimFunc func) { arith::Analyzer ana; - auto pass = BufferFlattener(&ana); + auto pass = BufferFlattener(ana.get()); pass.MarkBufferMapShapes(func); auto body = pass.VisitStmt(func->body); @@ -78,7 +78,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { using IRMutatorWithAnalyzer::VisitStmt; using IRMutatorWithAnalyzer::VisitStmt_; - explicit BufferFlattener(arith::Analyzer* ana) : IRMutatorWithAnalyzer(ana) {} + explicit BufferFlattener(arith::AnalyzerObj* ana) : IRMutatorWithAnalyzer(ana) {} Stmt VisitStmt_(const SBlockNode* op) final { TVM_FFI_ICHECK_EQ(op->match_buffers.size(), 0) diff --git a/src/tirx/transform/ir_utils.cc b/src/tirx/transform/ir_utils.cc index 281e53d76c4c..93aba86aac1b 100644 --- a/src/tirx/transform/ir_utils.cc +++ b/src/tirx/transform/ir_utils.cc @@ -677,8 +677,8 @@ ffi::Array GetBufferAllocationShape(const Buffer& buffer) { if (buffer->strides.size()) { TVM_FFI_ICHECK_EQ(buffer->shape.size(), buffer->strides.size()); for (size_t i = buffer->strides.size() - 1; i > 0; --i) { - TVM_FFI_ICHECK( - arith::Analyzer().CanProveEqual(floormod(buffer->strides[i - 1], buffer->strides[i]), 0)); + TVM_FFI_ICHECK(arith::Analyzer()->CanProveEqual( + floormod(buffer->strides[i - 1], buffer->strides[i]), 0)); alloc_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]); } } @@ -697,7 +697,7 @@ ffi::Array ConvertIndices(const MatchBufferRegion& match_buffer, size_t offset = source->region.size() - indices.size(); for (size_t i = 0; i < offset; ++i) { const Range& range = source->region[i]; - TVM_FFI_ICHECK(analyzer.CanProve(range->extent == 1)); + TVM_FFI_ICHECK(analyzer->CanProve(range->extent == 1)); result.push_back(range->min); } for (size_t i = 0; i < indices.size(); ++i) { @@ -719,7 +719,7 @@ Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region size_t offset = source->region.size() - region.size(); for (size_t i = 0; i < offset; ++i) { const Range& source_range = source->region[i]; - TVM_FFI_ICHECK(analyzer.CanProve(source_range->extent == 1)); + TVM_FFI_ICHECK(analyzer->CanProve(source_range->extent == 1)); result.push_back(Range::FromMinExtent(source_range->min, 1)); } for (size_t i = 0; i < region.size(); ++i) { @@ -735,7 +735,7 @@ ffi::Optional ConditionalBoundsContext::TrySolveCondition // extract equations and related vars from condition expression. // currently only extract simple integral equations which could be solvable. arith::Analyzer analyzer; - PrimExpr condition = analyzer.Simplify(condition_); + PrimExpr condition = analyzer->Simplify(condition_); if (is_const_int(condition)) { return std::nullopt; } @@ -797,7 +797,7 @@ ffi::Optional ConditionalBoundsContext::TrySolveCondition } } if (dom.defined()) { - ranges.Set(v, Range::FromMinExtent(dom.min(), analyzer.Simplify(dom.max() - dom.min() + 1))); + ranges.Set(v, Range::FromMinExtent(dom.min(), analyzer->Simplify(dom.max() - dom.min() + 1))); } } // solve constraints diff --git a/src/tirx/transform/lower_intrin.cc b/src/tirx/transform/lower_intrin.cc index 8de8fa442216..0b859ef9956f 100644 --- a/src/tirx/transform/lower_intrin.cc +++ b/src/tirx/transform/lower_intrin.cc @@ -46,7 +46,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { using IRMutatorWithAnalyzer::VisitStmt_; using FLowerGeneral = ffi::TypedFunction; - IntrinInjecter(arith::Analyzer* analyzer, const Target& tgt, bool enable_fast_math) + IntrinInjecter(arith::AnalyzerObj* analyzer, const Target& tgt, bool enable_fast_math) : IRMutatorWithAnalyzer(analyzer) { std::string target = tgt->kind->name; ffi::String mtriple = tgt->GetAttr("mtriple").value_or(""); @@ -366,7 +366,8 @@ Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) { arith::Analyzer analyzer; bool enable_fast_math = transform::PassContext::Current()->GetConfig("tirx.enable_fast_math", false).value(); - return IntrinInjecter(&analyzer, Target(ffi::String(target)), enable_fast_math)(std::move(stmt)); + return IntrinInjecter(analyzer.get(), Target(ffi::String(target)), + enable_fast_math)(std::move(stmt)); } namespace transform { @@ -378,7 +379,7 @@ Pass LowerIntrin() { TVM_FFI_ICHECK(target.defined()) << "LowerIntrin: Require the target attribute"; arith::Analyzer analyzer; bool enable_fast_math = ctx->GetConfig("tirx.enable_fast_math", false).value(); - n->body = IntrinInjecter(&analyzer, target.value(), enable_fast_math)(std::move(n->body)); + n->body = IntrinInjecter(analyzer.get(), target.value(), enable_fast_math)(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tirx.LowerIntrin", {}); diff --git a/src/tirx/transform/lower_tirx_cleanup.cc b/src/tirx/transform/lower_tirx_cleanup.cc index 318631fc939e..b192b110fbab 100644 --- a/src/tirx/transform/lower_tirx_cleanup.cc +++ b/src/tirx/transform/lower_tirx_cleanup.cc @@ -76,7 +76,7 @@ class LayoutApplier : public arith::IRMutatorWithAnalyzer { static std::pair> Flatten( const Stmt& stmt, const ffi::Map buffer_map, const Target& target) { arith::Analyzer ana; - LayoutApplier storage_lower(&ana, target); + LayoutApplier storage_lower(ana.get(), target); std::unordered_map new_buffer_map; std::vector param_flattened_buffers; for (const auto& kv : buffer_map) { @@ -101,7 +101,7 @@ class LayoutApplier : public arith::IRMutatorWithAnalyzer { using IRMutatorWithAnalyzer::VisitExpr_; using IRMutatorWithAnalyzer::VisitStmt_; - explicit LayoutApplier(arith::Analyzer* analyzer, const Target& target) + explicit LayoutApplier(arith::AnalyzerObj* analyzer, const Target& target) : arith::IRMutatorWithAnalyzer(analyzer), target_(target) {} ffi::Any VisitAny(const ffi::Any& any) { @@ -187,7 +187,7 @@ class LayoutApplier : public arith::IRMutatorWithAnalyzer { } flattened = buf; writer = flattened.CopyOnWrite(); - writer->shape = {ana.Simplify(mem_span)}; + writer->shape = {ana->Simplify(mem_span)}; writer->strides = {}; writer->axis_separators = {}; } else { diff --git a/src/tirx/transform/lower_warp_memory.cc b/src/tirx/transform/lower_warp_memory.cc index 99c815bf6630..37c9fec6b72e 100644 --- a/src/tirx/transform/lower_warp_memory.cc +++ b/src/tirx/transform/lower_warp_memory.cc @@ -106,7 +106,7 @@ namespace tirx { // store warp_mem[m * warp_index + (width * m) * y + x] class WarpStoreCoeffFinder : private StmtExprVisitor { public: - WarpStoreCoeffFinder(const VarNode* buffer, Var warp_index, arith::Analyzer* analyzer) + WarpStoreCoeffFinder(const VarNode* buffer, Var warp_index, arith::AnalyzerObj* analyzer) : buffer_(buffer), warp_index_(warp_index), analyzer_(analyzer) {} // find the warp co-efficient in the statement given the warp size int Find(const Stmt& stmt) { @@ -193,7 +193,7 @@ class WarpStoreCoeffFinder : private StmtExprVisitor { // the coefficient int64_t warp_coeff_{0}; // analyzer. - arith::Analyzer* analyzer_; + arith::AnalyzerObj* analyzer_; }; // Visitor to find the warp index @@ -244,7 +244,7 @@ class WarpIndexFinder : private StmtVisitor { // Mutator to change the read pattern class WarpAccessRewriter : protected StmtExprMutator { public: - explicit WarpAccessRewriter(int warp_size, arith::Analyzer* analyzer) + explicit WarpAccessRewriter(int warp_size, arith::AnalyzerObj* analyzer) : warp_size_(warp_size), analyzer_(analyzer) {} // Rewrite the AllocBuffer statement which transforms // warp memory to local memory. @@ -427,7 +427,7 @@ class WarpAccessRewriter : protected StmtExprMutator { // the coefficient n int warp_group_{0}; // Internal analyzer - arith::Analyzer* analyzer_; + arith::AnalyzerObj* analyzer_; }; // Bind bound information of variables to make analyzer more effective @@ -435,7 +435,7 @@ class WarpAccessRewriter : protected StmtExprMutator { // so analysis can be context independent. class BindVarBoundInfo : public StmtVisitor { public: - explicit BindVarBoundInfo(arith::Analyzer* analyzer) : analyzer_(analyzer) {} + explicit BindVarBoundInfo(arith::AnalyzerObj* analyzer) : analyzer_(analyzer) {} void VisitStmt_(const ForNode* op) final { const Var& loop_var = op->loop_var; @@ -458,7 +458,7 @@ class BindVarBoundInfo : public StmtVisitor { protected: // internal analyzer. - arith::Analyzer* analyzer_; + arith::AnalyzerObj* analyzer_; // variable domain std::unordered_map var_dom_; }; @@ -470,7 +470,7 @@ class WarpMemoryRewriter : private StmtMutator { Stmt Rewrite(Stmt stmt) { if (warp_size_ == 1) return stmt; - BindVarBoundInfo binder(&analyzer_); + BindVarBoundInfo binder(analyzer_.get()); binder(stmt); stmt = operator()(std::move(stmt)); return stmt; @@ -493,7 +493,7 @@ class WarpMemoryRewriter : private StmtMutator { remaining.push_back(op->seq[j]); } Stmt body = remaining.empty() ? Stmt(Evaluate(0)) : SeqStmt::Flatten(remaining); - WarpAccessRewriter rewriter(warp_size_, &analyzer_); + WarpAccessRewriter rewriter(warp_size_, analyzer_.get()); Stmt rewritten = rewriter.Rewrite(alloc, body); new_seq.push_back(rewritten); changed = true; diff --git a/src/tirx/transform/narrow_datatype.cc b/src/tirx/transform/narrow_datatype.cc index 771ea674f22a..b1816e9d9962 100644 --- a/src/tirx/transform/narrow_datatype.cc +++ b/src/tirx/transform/narrow_datatype.cc @@ -82,7 +82,7 @@ class DataTypeVisitor final : public StmtExprVisitor { if (e.dtype().is_int()) { int bits = max_bits_; if (bound_.find(e) == bound_.end()) { - analyzer_.const_int_bound(e, &bound_); + analyzer_->const_int_bound(e, &bound_); } ConstIntBound bound = bound_[e]; int64_t ubound = Downcast(max_value(DataType::Int(target_bits_)))->value; @@ -108,14 +108,14 @@ class DataTypeVisitor final : public StmtExprVisitor { } void VisitStmt_(const ForNode* op) { - analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); vextent_[op->loop_var.as()] = op->extent.dtype(); return StmtExprVisitor::VisitStmt_(op); } void VisitStmt_(const SBlockNode* op) { for (const IterVar& iter : op->iter_vars) { - analyzer_.Bind(iter->var, Range::FromMinExtent(iter->dom->min, iter->dom->extent)); + analyzer_->Bind(iter->var, Range::FromMinExtent(iter->dom->min, iter->dom->extent)); vextent_[iter->var.as()] = iter->dom->extent.dtype(); } StmtExprVisitor::VisitStmt_(op); @@ -125,7 +125,7 @@ class DataTypeVisitor final : public StmtExprVisitor { if (op->attr_key == attr::thread_extent || op->attr_key == s_tir::attr::virtual_thread) { IterVar iv = Downcast(op->node); TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U); - analyzer_.Bind(iv->var, Range::FromMinExtent(0, op->value)); + analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value)); vextent_[iv->var.as()] = op->value.dtype(); StmtExprVisitor::VisitStmt_(op); } else { @@ -136,7 +136,7 @@ class DataTypeVisitor final : public StmtExprVisitor { void VisitExpr_(const ReduceNode* op) { // Setup the domain information before simplification. for (const IterVar& iv : op->axis) { - analyzer_.Bind(iv->var, iv->dom); + analyzer_->Bind(iv->var, iv->dom); vextent_[iv->var.as()] = iv->dom->extent.dtype(); } // Recursively call simplification when necessary. diff --git a/src/tirx/transform/remove_no_op.cc b/src/tirx/transform/remove_no_op.cc index 133cfa9d9a56..8ae06ea9a37b 100644 --- a/src/tirx/transform/remove_no_op.cc +++ b/src/tirx/transform/remove_no_op.cc @@ -74,7 +74,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tirx.RemoveNoOp", RemoveNoOpConfig); // Mark the statement of each stage. class NoOpRemover : public arith::IRMutatorWithAnalyzer { public: - static Stmt Apply(Stmt stmt, arith::Analyzer* analyzer, bool ignore_profiler_call = false) { + static Stmt Apply(Stmt stmt, arith::AnalyzerObj* analyzer, bool ignore_profiler_call = false) { NoOpRemover visitor(analyzer, ignore_profiler_call); return visitor(std::move(stmt)); } @@ -84,7 +84,7 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { using Parent::VisitStmt; using Parent::VisitStmt_; - NoOpRemover(arith::Analyzer* analyzer, bool ignore_profiler_call = false) + NoOpRemover(arith::AnalyzerObj* analyzer, bool ignore_profiler_call = false) : Parent(analyzer), ignore_profiler_call_(ignore_profiler_call) {} Stmt VisitStmt_(const BindNode* op) final { @@ -99,7 +99,7 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { auto wait_attrs = GetAsyncWaitAttributes(op); auto wait_cnt = wait_attrs.second; arith::Analyzer ana; - if (ana.CanProve(wait_cnt < 0)) { + if (ana->CanProve(wait_cnt < 0)) { // A negative wait count can arise if it depends on a loop variable. // For example, a wait count 1 - i can be negative after loop unrolling. // We assume that such wait is a nop. @@ -263,7 +263,7 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { bool ignore_profiler_call_{false}; }; -Stmt RemoveNoOp(Stmt stmt, arith::Analyzer* analyzer, bool ignore_profiler_call) { +Stmt RemoveNoOp(Stmt stmt, arith::AnalyzerObj* analyzer, bool ignore_profiler_call) { return NoOpRemover::Apply(std::move(stmt), analyzer, ignore_profiler_call); } @@ -276,14 +276,14 @@ Pass RemoveNoOp() { .value_or(tvm::transform::PassConfigWithDefaults()); arith::Analyzer analyzer; - analyzer.rewrite_simplify.SetMaximumRewriteSteps(config->max_simplification_steps); + analyzer->rewrite_simplify.SetMaximumRewriteSteps(config->max_simplification_steps); bool ignore_profiler_call = config->ignore_profiler_call; { auto* write_ptr = f.CopyOnWrite(); write_ptr->body = - NoOpRemover::Apply(std::move(write_ptr->body), &analyzer, ignore_profiler_call); + NoOpRemover::Apply(std::move(write_ptr->body), analyzer.get(), ignore_profiler_call); } return f; }; diff --git a/src/tirx/transform/remove_no_op.h b/src/tirx/transform/remove_no_op.h index 21d1f917d50b..cd9710b61791 100644 --- a/src/tirx/transform/remove_no_op.h +++ b/src/tirx/transform/remove_no_op.h @@ -41,7 +41,7 @@ namespace tirx { * * \return The modified statement with no-ops removed */ -Stmt RemoveNoOp(Stmt stmt, arith::Analyzer* analyzer, bool ignore_profiler_call = false); +Stmt RemoveNoOp(Stmt stmt, arith::AnalyzerObj* analyzer, bool ignore_profiler_call = false); } // namespace tirx } // namespace tvm diff --git a/src/tirx/transform/stmt_simplify.cc b/src/tirx/transform/stmt_simplify.cc index 9ebbcab9e133..d7dd4599f4fc 100644 --- a/src/tirx/transform/stmt_simplify.cc +++ b/src/tirx/transform/stmt_simplify.cc @@ -98,7 +98,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tirx.StmtSimplify", StmtSimplifyConfig); class StmtSimplifier : public IRMutatorWithAnalyzer { public: - static PrimFunc Apply(PrimFunc func, Analyzer* analyzer, + static PrimFunc Apply(PrimFunc func, AnalyzerObj* analyzer, ffi::Optional config_opt = std::nullopt) { auto config = config_opt.value_or(MakeDefaultStmtSimplifyConfig()); analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions()); @@ -110,7 +110,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { } private: - explicit StmtSimplifier(Analyzer* analyzer, StmtSimplifyConfig config) + explicit StmtSimplifier(AnalyzerObj* analyzer, StmtSimplifyConfig config) : IRMutatorWithAnalyzer(analyzer), config_(config) {} using Parent = IRMutatorWithAnalyzer; @@ -250,7 +250,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { namespace tirx { -PrimFunc StmtSimplify(PrimFunc func, arith::Analyzer* analyzer) { +PrimFunc StmtSimplify(PrimFunc func, arith::AnalyzerObj* analyzer) { return arith::StmtSimplifier::Apply(std::move(func), analyzer); } @@ -261,7 +261,7 @@ Pass StmtSimplify() { arith::Analyzer analyzer; auto cfg = ctx->GetConfig("tirx.StmtSimplify"); - return arith::StmtSimplifier::Apply(f, &analyzer, cfg); + return arith::StmtSimplifier::Apply(f, analyzer.get(), cfg); }; return CreatePrimFuncPass(pass_func, 0, "tirx.StmtSimplify", {}); } diff --git a/src/tirx/transform/stmt_simplify.h b/src/tirx/transform/stmt_simplify.h index 2e5e9b48cabb..5f10397e839a 100644 --- a/src/tirx/transform/stmt_simplify.h +++ b/src/tirx/transform/stmt_simplify.h @@ -34,7 +34,7 @@ namespace tirx { * * Applies the same behavior as the tirx.transform.StmtSimplify pass. */ -PrimFunc StmtSimplify(PrimFunc func, arith::Analyzer* analyzer); +PrimFunc StmtSimplify(PrimFunc func, arith::AnalyzerObj* analyzer); } // namespace tirx } // namespace tvm diff --git a/src/tirx/transform/storage_rewrite.cc b/src/tirx/transform/storage_rewrite.cc index 66d04ca89997..ddc16b0394e4 100644 --- a/src/tirx/transform/storage_rewrite.cc +++ b/src/tirx/transform/storage_rewrite.cc @@ -745,13 +745,13 @@ class StoragePlanRewriter : public StmtExprMutator { } // transform to alloc bytes auto type_bits = alloc_type.bits() * alloc_type.lanes(); - bool divided = analyzer_.CanProve(indexmod(combo_size, type_bits) == 0); + bool divided = analyzer_->CanProve(indexmod(combo_size, type_bits) == 0); combo_size = indexdiv(combo_size, type_bits); // round up for can not divided if (!divided) { combo_size = combo_size + make_const(DataType::Int(32), 1); } - combo_size = analyzer_.Simplify(combo_size); + combo_size = analyzer_->Simplify(combo_size); Buffer buf(e->alloc_var, alloc_type, {combo_size}, {}, PrimExpr(), e->alloc_var->name_hint, 0, 0, BufferType::kDefault); ffi::Map annotations; @@ -1171,7 +1171,7 @@ struct BufferVarInfo { } } arith::Analyzer analyzer_; - arith::ModularSet me = analyzer_.modular_set(extent); + arith::ModularSet me = analyzer_->modular_set(extent); if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) { preferred_lanes = lanes; } @@ -1389,7 +1389,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { if (ramp_index && is_one(ramp_index->stride)) { if (ramp_index->lanes->IsInstance()) { int lanes = static_cast(Downcast(ramp_index->lanes)->value); - arith::ModularSet me = analyzer_.modular_set(ramp_index->base); + arith::ModularSet me = analyzer_->modular_set(ramp_index->base); if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) { lanes_used = lanes; } @@ -1400,7 +1400,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { if (detect_scalar_read_patterns_ && is_buffer_load && indices.size()) { const PrimExpr last_dim_index = indices[indices.size() - 1]; if (last_dim_index.dtype().lanes() == 1) { - arith::ModularSet me = analyzer_.modular_set(last_dim_index); + arith::ModularSet me = analyzer_->modular_set(last_dim_index); var_info.scalar_read_dtype.emplace(access_dtype.with_lanes(me->coeff)); return; } @@ -1539,7 +1539,7 @@ class VectorTypeRewriter : public StmtExprMutator { } indices.Set(indices.size() - 1, new_index); } else if (last_dim_index.dtype().lanes() == 1 && info.factor() > 1) { - arith::ModularSet me = analyzer_.modular_set(last_dim_index); + arith::ModularSet me = analyzer_->modular_set(last_dim_index); TVM_FFI_ICHECK(me->coeff == 0 || info.factor() % me->coeff == 0); PrimExpr new_index = last_dim_index / make_const(last_dim_index.dtype(), info.factor()); shuffle_index = me->base % info.factor(); diff --git a/src/tirx/transform/tile_primitive_dispatch.cc b/src/tirx/transform/tile_primitive_dispatch.cc index 0e0e4932caf4..1a97db156b0c 100644 --- a/src/tirx/transform/tile_primitive_dispatch.cc +++ b/src/tirx/transform/tile_primitive_dispatch.cc @@ -1033,14 +1033,14 @@ class TilePrimitiveDispatcher : public StmtExprMutator { bool TryExtractLinearScopeDiff(const PrimExpr& diff, ScopeIdTarget* target, int64_t* coeff, int64_t* base) { - PrimExpr simplified = analyzer_.Simplify(diff); + PrimExpr simplified = analyzer_->Simplify(diff); for (const auto& [var, candidate] : ScopeIdTargets()) { ffi::Array linear = arith::DetectLinearEquation(simplified, {var}); if (linear.size() != 2) continue; int64_t c = 0; int64_t b = 0; - if (!TryExtractIntImm(analyzer_.Simplify(linear[0]), &c) || - !TryExtractIntImm(analyzer_.Simplify(linear[1]), &b)) { + if (!TryExtractIntImm(analyzer_->Simplify(linear[0]), &c) || + !TryExtractIntImm(analyzer_->Simplify(linear[1]), &b)) { continue; } if (c != 1 && c != -1) continue; @@ -1126,7 +1126,7 @@ class TilePrimitiveDispatcher : public StmtExprMutator { auto maybe_target = ResolveScopeIdTarget(lhs); if (!maybe_target) return false; int64_t mod_value = 0; - if (!TryExtractIntImm(analyzer_.Simplify(rhs), &mod_value) || mod_value <= 0) return false; + if (!TryExtractIntImm(analyzer_->Simplify(rhs), &mod_value) || mod_value <= 0) return false; *target = *maybe_target; *modulus = mod_value; return true; @@ -1137,11 +1137,11 @@ class TilePrimitiveDispatcher : public StmtExprMutator { int64_t modulus = 0; int64_t residue = 0; if (TryExtractModuloTarget(lhs, &target, &modulus) && - TryExtractIntImm(analyzer_.Simplify(rhs), &residue)) { + TryExtractIntImm(analyzer_->Simplify(rhs), &residue)) { return TryPushModuloForTarget(target, modulus, residue); } if (TryExtractModuloTarget(rhs, &target, &modulus) && - TryExtractIntImm(analyzer_.Simplify(lhs), &residue)) { + TryExtractIntImm(analyzer_->Simplify(lhs), &residue)) { return TryPushModuloForTarget(target, modulus, residue); } return false; diff --git a/src/tirx/transform/tvm_ffi_binder.cc b/src/tirx/transform/tvm_ffi_binder.cc index 16b7eab7af2c..3c37b51b59be 100644 --- a/src/tirx/transform/tvm_ffi_binder.cc +++ b/src/tirx/transform/tvm_ffi_binder.cc @@ -179,7 +179,7 @@ bool TVMFFIABIBuilder::BindScalar(const PrimExpr& arg, const PrimExpr& value, } else { // Duplicate bind: create rich assertion with both paths PrimExpr prev_value = it->second.value; - PrimExpr scond = analyzer_.Simplify(prev_value == value); + PrimExpr scond = analyzer_->Simplify(prev_value == value); if (is_zero(scond)) { TVM_FFI_THROW(InternalError) << "Bind have an unmet assertion: " << prev_value << " == " << value << " at " << RenderAccessPath(path); @@ -209,7 +209,7 @@ bool TVMFFIABIBuilder::BindScalar(const PrimExpr& arg, const PrimExpr& value, } else { // Non-Var expression (e.g. batch_size + 1): defer assertion to Finalize() // so display-var substitution can render human-readable names. - PrimExpr scond = analyzer_.Simplify(arg == value); + PrimExpr scond = analyzer_->Simplify(arg == value); if (is_zero(scond)) { TVM_FFI_THROW(InternalError) << "Bind have an unmet assertion: " << arg << " == " << value << " at " << RenderAccessPath(path); @@ -370,7 +370,7 @@ void TVMFFIABIBuilder::BindBuffer(const Buffer& arg, const Buffer& value, PrimExpr offset = value->elem_offset; PrimExpr factor = make_const(offset.dtype(), arg->offset_factor); PrimExpr zero = make_zero(offset.dtype()); - PrimExpr acond = analyzer_.Simplify(truncmod(offset, factor) == zero); + PrimExpr acond = analyzer_->Simplify(truncmod(offset, factor) == zero); if (is_zero(acond)) { TVM_FFI_THROW(InternalError) << "Bind have an unmet assertion at " << RenderAccessPath(offset_path); @@ -394,7 +394,7 @@ void TVMFFIABIBuilder::BindBuffer(const Buffer& arg, const Buffer& value, TVM_FFI_ICHECK(fuzzy_match) << "Buffer size mismatch at " << RenderAccessPath(base_path); size_t diff = value->shape.size() - arg->shape.size(); for (size_t i = 0; i < diff; ++i) { - TVM_FFI_ICHECK(is_one(analyzer_.Simplify(value->shape[i]))) + TVM_FFI_ICHECK(is_one(analyzer_->Simplify(value->shape[i]))) << "Buffer shape mismatch at " << RenderAccessPath(base_path) << ": " << arg->shape << " vs " << value->shape; } @@ -613,7 +613,7 @@ void TVMFFIABIBuilder::BindAutoBroadcastStrides(const Buffer& buffer, const Var& ffi::reflection::AccessPath strides_k_path = param_path->Attr(ffi::String("strides"))->ArrayItem(k); BindScalar(buffer->strides[k], value, strides_k_path, true); - stride = analyzer_.Simplify(stride * buffer->shape[k]); + stride = analyzer_->Simplify(stride * buffer->shape[k]); } } @@ -715,7 +715,7 @@ void TVMFFIABIBuilder::DecodeParamDLTensor(const Buffer& buffer, const PrimExpr& PrimExpr offset = buffer->elem_offset; PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor); PrimExpr zero = make_zero(offset.dtype()); - PrimExpr acond = analyzer_.Simplify(truncmod(offset, factor) == zero); + PrimExpr acond = analyzer_->Simplify(truncmod(offset, factor) == zero); if (is_zero(acond)) { TVM_FFI_THROW(InternalError) << "Bind have an unmet assertion at " << RenderAccessPath(byte_offset_path); @@ -737,7 +737,7 @@ void TVMFFIABIBuilder::DecodeParamDLTensor(const Buffer& buffer, const PrimExpr& // Use custom assertion for device_type to show human-readable device name if (const auto* const_dt = device_type_.as()) { PrimExpr cond = - analyzer_.Simplify(make_const(DataType::Int(32), const_dt->value) == actual_device_type); + analyzer_->Simplify(make_const(DataType::Int(32), const_dt->value) == actual_device_type); if (!is_one(cond)) { std::string device_name = runtime::DLDeviceType2Str(static_cast(const_dt->value)); EmitAssert(cond, "ValueError", // diff --git a/src/tirx/transform/unroll_loop.cc b/src/tirx/transform/unroll_loop.cc index 4a6beae92f0f..860a71d0043c 100644 --- a/src/tirx/transform/unroll_loop.cc +++ b/src/tirx/transform/unroll_loop.cc @@ -236,7 +236,7 @@ class LoopUnroller : public StmtExprMutator { // returns the extent of the loop if it's a constant integer, otherwise return -1 int GetExtent(const ForNode* op) { // constant folding. - PrimExpr extent = analyzer_.Simplify(op->extent); + PrimExpr extent = analyzer_->Simplify(op->extent); const IntImmNode* v1 = extent.as(); int value = -1; // integers that do not fit in int32_t are treated as symbolic, diff --git a/src/tirx/transform/vectorize_loop.cc b/src/tirx/transform/vectorize_loop.cc index 540e641bdff1..a1e954f95184 100644 --- a/src/tirx/transform/vectorize_loop.cc +++ b/src/tirx/transform/vectorize_loop.cc @@ -258,7 +258,7 @@ class VecAllocAccess : public StmtExprMutator { // var_lanes_. Typically, this will be a 1-d index into a flat // memory space. ffi::Array shape = node->buffer->shape; - shape.Set(shape.size() - 1, analyzer_.Simplify(shape[shape.size() - 1] * var_lanes_)); + shape.Set(shape.size() - 1, analyzer_->Simplify(shape[shape.size() - 1] * var_lanes_)); // TODO(Lunderberg): Move this pass to be prior to // FlattenBuffer, implement by appending a @@ -273,7 +273,7 @@ class VecAllocAccess : public StmtExprMutator { if (i != strides.size() - 1) { stride *= var_lanes_; } - strides.push_back(analyzer_.Simplify(stride)); + strides.push_back(analyzer_->Simplify(stride)); } // Copy everything into the new buffer. @@ -288,7 +288,7 @@ class VecAllocAccess : public StmtExprMutator { // variable. ffi::Array indices = node->indices; indices.Set(indices.size() - 1, - analyzer_.Simplify(indices[indices.size() - 1] * var_lanes_ + var_)); + analyzer_->Simplify(indices[indices.size() - 1] * var_lanes_ + var_)); auto writer = node.CopyOnWrite(); writer->buffer = buf; @@ -358,11 +358,11 @@ class Vectorizer : public StmtMutator, public ExprFunctor(); const RampNode* a_ramp = a.as(); - if (a_ramp && b.dtype().is_scalar() && analyzer_.CanProve(b > 0)) { + if (a_ramp && b.dtype().is_scalar() && analyzer_->CanProve(b > 0)) { PrimExpr lanes = a_ramp->lanes; return Ramp(a_ramp->base * b, a_ramp->stride * b, lanes); } - if (b_ramp && a.dtype().is_scalar() && analyzer_.CanProve(a > 0)) { + if (b_ramp && a.dtype().is_scalar() && analyzer_->CanProve(a > 0)) { PrimExpr lanes = b_ramp->lanes; return Ramp(b_ramp->base * a, b_ramp->stride * a, lanes); } @@ -412,8 +412,8 @@ class Vectorizer : public StmtMutator, public ExprFunctor(); int op_lanes = static_cast(Downcast(op->lanes)->value); int base_ramp_lanes = static_cast(Downcast(base_ramp->lanes)->value); - if (analyzer_.CanProve(base_ramp->stride == - stride * make_const(stride.dtype(), base_ramp_lanes))) { + if (analyzer_->CanProve(base_ramp->stride == + stride * make_const(stride.dtype(), base_ramp_lanes))) { return Ramp(base_ramp->base, stride, op_lanes * base_ramp_lanes); } } diff --git a/tests/cpp/arith_simplify_test.cc b/tests/cpp/arith_simplify_test.cc index 2c7b9cea2472..a39a39149cd2 100644 --- a/tests/cpp/arith_simplify_test.cc +++ b/tests/cpp/arith_simplify_test.cc @@ -27,11 +27,11 @@ TEST(Simplify, MinMax) { tvm::arith::Analyzer ana; auto x = tvm::te::var("x"); auto e1 = (tvm::max(x, 1) - tvm::max(x, 1)); - auto e1s = ana.canonical_simplify(e1); + auto e1s = ana->canonical_simplify(e1); TVM_FFI_ICHECK(tvm::tirx::is_zero(e1s)); auto e2 = (x * tvm::min(x, 1)) - (x * tvm::min(x, 1)); - auto e2s = ana.canonical_simplify(e2); + auto e2s = ana->canonical_simplify(e2); TVM_FFI_ICHECK(tvm::tirx::is_zero(e2s)); } @@ -39,7 +39,7 @@ TEST(Simplify, Mul) { tvm::arith::Analyzer ana; auto x = tvm::te::var("x"); auto e = (x * x) - (x * x); - auto es = ana.canonical_simplify(e); + auto es = ana->canonical_simplify(e); TVM_FFI_ICHECK(tvm::tirx::is_zero(es)); } @@ -50,11 +50,31 @@ TEST(Simplify, Mod) { // Mod::make is used instead of % to avoid constant folding during // calling operator%(x,y). Mod::make doesn't try constant folding, // and therefore, the constant folding will be attempted in CanonicalSimplify - auto mod = ana.canonical_simplify(tvm::tirx::Mod(x, y)); - auto es = ana.canonical_simplify(mod - x); + auto mod = ana->canonical_simplify(tvm::tirx::Mod(x, y)); + auto es = ana->canonical_simplify(mod - x); TVM_FFI_ICHECK(tvm::tirx::is_zero(es)); } +TEST(AnalyzerObjectRef, CopySharesMutableState) { + tvm::arith::Analyzer analyzer; + tvm::arith::Analyzer copy = analyzer; + auto x = tvm::te::var("x"); + + copy->Bind(x, tvm::Range::FromMinExtent(0, 8)); + + TVM_FFI_ICHECK(analyzer->CanProve(x < 8)); +} + +TEST(AnalyzerObjectRef, ConstHandleRefCanMutateAnalyzerState) { + tvm::arith::Analyzer analyzer; + const tvm::arith::Analyzer& analyzer_ref = analyzer; + auto x = tvm::te::var("x"); + + analyzer_ref->Bind(x, tvm::Range::FromMinExtent(0, 8)); + + TVM_FFI_ICHECK(analyzer->CanProve(x < 8)); +} + TEST(ConstantFold, Broadcast) { tvm::ffi::StructuralEqual checker; auto i32x4 = tvm::tirx::Broadcast(tvm::IntImm(tvm::DataType::Int(32), 10), 4); diff --git a/tests/cpp/threading_backend_test.cc b/tests/cpp/threading_backend_test.cc index 8c30aaeb1e4e..e68e0fdba832 100644 --- a/tests/cpp/threading_backend_test.cc +++ b/tests/cpp/threading_backend_test.cc @@ -25,6 +25,7 @@ #include #include +#include #include #include #include diff --git a/tests/python/arith/test_arith_analyzer_object.py b/tests/python/arith/test_arith_analyzer_object.py new file mode 100644 index 000000000000..a88c413b43ea --- /dev/null +++ b/tests/python/arith/test_arith_analyzer_object.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm import tirx +from tvm.runtime import Object + + +def test_analyzer_is_ffi_object_with_persistent_state(): + analyzer = tvm.arith.Analyzer() + x = tirx.Var("x", "int64") + + assert isinstance(analyzer, Object) + + analyzer.bind(x, tvm.ir.Range(0, 8)) + assert analyzer.const_int_bound_is_bound(x) + assert analyzer.can_prove(x < 8) + assert not analyzer.can_prove(x < 4) + + bound = analyzer.const_int_bound(x + 1) + assert bound.min_value == 1 + assert bound.max_value == 8 + + +def test_analyzer_object_constraint_scope_and_override_bind(): + analyzer = tvm.arith.Analyzer() + x = tirx.Var("x", "int64") + + with analyzer.constraint_scope(x % 3 == 0): + assert analyzer.modular_set(x).coeff == 3 + + assert analyzer.modular_set(x).coeff != 3 + + analyzer = tvm.arith.Analyzer() + y = tirx.Var("y", "int64") + analyzer.bind(y, tirx.const(4, "int64")) + tvm.ir.assert_structural_equal(analyzer.simplify(y + 1), tirx.const(5, "int64")) + + analyzer.bind(y, tirx.const(8, "int64"), allow_override=True) + tvm.ir.assert_structural_equal(analyzer.simplify(y + 1), tirx.const(9, "int64")) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/tirx-base/test_tir_index_map.py b/tests/python/tirx-base/test_tir_index_map.py index 28b75d8f62c2..a439728b5e84 100644 --- a/tests/python/tirx-base/test_tir_index_map.py +++ b/tests/python/tirx-base/test_tir_index_map.py @@ -17,13 +17,14 @@ # ruff: noqa: E741, F401 import numpy as np import pytest +import tvm_ffi import tvm import tvm.testing from tvm.ir import assert_structural_equal from tvm.runtime import const from tvm.script import tirx as T -from tvm.tirx import IndexMap, IntImm, floordiv, floormod +from tvm.tirx import IndexMap, IntImm, floordiv, floormod, stmt_functor def assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None: @@ -46,6 +47,32 @@ def test_index_mapping(): assert_structural_equal(index_map.map_indices([T.int64(42)]), [T.int64(10), T.int64(2)]) +def test_map_indices_accepts_external_analyzer(): + tile = tvm.tirx.Var("tile", "int32") + index_map = IndexMap.from_func(lambda i: [i // tile], index_dtype="int32") + analyzer = tvm.arith.Analyzer() + + unsimplified = index_map.map_indices([T.int32(32)])[0] + analyzer.bind(tile, T.int32(16)) + simplified = index_map.map_indices([T.int32(32)], analyzer=analyzer)[0] + + assert not tvm_ffi.structural_equal(unsimplified, T.int32(2)) + assert_structural_equal(simplified, T.int32(2)) + + +def test_map_shape_accepts_external_analyzer(): + tile = tvm.tirx.Var("tile", "int32") + index_map = IndexMap.from_func(lambda i: [i // tile, i % tile], index_dtype="int32") + analyzer = tvm.arith.Analyzer() + + unsimplified = index_map.map_shape([T.int32(32)])[0] + analyzer.bind(tile, T.int32(16)) + simplified = index_map.map_shape([T.int32(32)], analyzer=analyzer) + + assert not tvm_ffi.structural_equal(unsimplified, T.int32(2)) + assert_structural_equal(simplified, [T.int32(2), tile]) + + def test_shape_mapping(): index_map = IndexMap.from_func(lambda i: [i // 4, i % 4], index_dtype="int32") @@ -64,6 +91,18 @@ def test_inverse(): assert index_map.inverse([16]).is_equivalent_to(expected_inverse) +def test_inverse_accepts_external_analyzer(): + tile = tvm.tirx.Var("tile", "int32") + index_map = IndexMap.from_func(lambda i: [i // tile, i % tile], index_dtype="int32") + analyzer = tvm.arith.Analyzer() + + analyzer.bind(tile, T.int32(16)) + inverse = index_map.inverse([T.int32(32)], analyzer=analyzer) + mapped = inverse.map_indices([T.int32(1), T.int32(3)], analyzer=analyzer) + + assert_structural_equal(mapped, [T.int32(19)]) + + def test_nonbijective_inverse_gives_error(): index_map = IndexMap.from_func(lambda i: [i // 4, i % 4]) @@ -198,6 +237,29 @@ def test_nonsurjective_inverse(padding_test_case): tvm.ir.assert_structural_equal(padding_predicate, expected_predicate) +def test_non_surjective_inverse_accepts_external_analyzer(): + tile = tvm.tirx.Var("tile", "int32") + index_map = IndexMap.from_func(lambda i: [i // tile, i % tile], index_dtype="int32") + analyzer = tvm.arith.Analyzer() + + analyzer.bind(tile, T.int32(16)) + inverse, padding_predicate = index_map.non_surjective_inverse([T.int32(31)], analyzer=analyzer) + mapped = inverse.map_indices([T.int32(1), T.int32(15)], analyzer=analyzer) + + assert_structural_equal(mapped, [T.int32(31)]) + + padding_at_last_element = stmt_functor.substitute( + padding_predicate, + {inverse.initial_indices[0]: T.int32(1), inverse.initial_indices[1]: T.int32(15)}, + ) + padding_at_first_element = stmt_functor.substitute( + padding_predicate, + {inverse.initial_indices[0]: T.int32(0), inverse.initial_indices[1]: T.int32(0)}, + ) + assert_structural_equal(analyzer.simplify(padding_at_last_element), T.bool(True)) + assert_structural_equal(analyzer.simplify(padding_at_first_element), T.bool(False)) + + def test_index_map_inverse_no_iter(): def input_example(i0, i1, i2, i3): j0 = floordiv(i3, 32) From 23af9ce883e6024542e101159db40f068c622dab Mon Sep 17 00:00:00 2001 From: Ubospica Date: Sun, 7 Jun 2026 21:46:52 -0400 Subject: [PATCH 2/2] [Arith] Expose more Analyzer methods, reuse a shared analyzer, and add tests Now that arith::Analyzer is a tvm-ffi object, expose more of its stateful surface to Python and let callers share a single analyzer across FFI calls. - Add Python bindings for additional Analyzer methods. - Accept optional analyzers in iter-map, int-set, and IndexMap helper APIs. - Keep IndexMap inverse temporary bindings out of caller-provided analyzers. - Add targeted tests for optional analyzer reuse and state isolation. --- python/tvm/arith/__init__.py | 2 +- python/tvm/arith/analyzer.py | 79 +++++++++- python/tvm/arith/int_set.py | 24 ++- python/tvm/arith/iter_affine_map.py | 35 +++- python/tvm/tirx/function.py | 14 +- src/arith/analyzer.cc | 26 ++- src/arith/int_set.cc | 18 +-- src/arith/iter_affine_map.cc | 25 +-- src/tirx/ir/index_map.cc | 8 +- .../arith/test_arith_analyzer_object.py | 149 ++++++++++++++++++ tests/python/arith/test_arith_intset.py | 28 ++++ .../arith/test_arith_iter_affine_map.py | 58 +++++++ tests/python/tirx-base/test_tir_index_map.py | 31 +++- 13 files changed, 446 insertions(+), 51 deletions(-) diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index 84de36bb7880..b4a131cff7d0 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -25,7 +25,7 @@ estimate_region_strict_bound, estimate_region_upper_bound, ) -from .analyzer import ModularSet, ConstIntBound, Analyzer, ProofStrength, Extension +from .analyzer import ModularSet, ConstIntBound, Analyzer, ProofStrength, Extension, CompareResult from .bound import deduce_bound from .pattern import detect_linear_equation, detect_clip_bound from .int_solver import solve_linear_equations, solve_linear_inequalities diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index c3b77a9603ff..0aa6a75eba4a 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -35,6 +35,22 @@ class ProofStrength(enum.IntEnum): SYMBOLIC_BOUND = 1 +class CompareResult(enum.IntEnum): + """Result of a transitive comparison. + + Values must match the C++ ``arith::CompareResult`` enum. + """ + + INCONSISTENT = 0 + EQ = 1 + LT = 2 + LE = 3 + GT = 4 + GE = 5 + NE = 6 + UNKNOWN = 7 + + class Extension(enum.Flag): """Extensions enabled for RewriteSimplifier @@ -214,7 +230,7 @@ def canonical_simplify(self, expr: tirx.PrimExpr) -> tirx.PrimExpr: """ return _ffi_api.AnalyzerCanonicalSimplify(self, expr) - def int_set(self, expr: tirx.PrimExpr, dom_map: dict[tirx.Var, IntSet]) -> IntSet: + def int_set(self, expr: tirx.PrimExpr, dom_map: dict[tirx.Var, IntSet] | None = None) -> IntSet: """Compute a symbolic IntSet that covers expr for all values in dom_map. Parameters @@ -222,8 +238,9 @@ def int_set(self, expr: tirx.PrimExpr, dom_map: dict[tirx.Var, IntSet]) -> IntSe expr : PrimExpr The expression. - dom_map : Dict[tvm.tirx.Var, tvm.arith.IntSet] - The domain for variables to be relaxed. + dom_map : Optional[Dict[tvm.tirx.Var, tvm.arith.IntSet]] + The domain for variables to be relaxed. When omitted, the analyzer + uses the domains of the variables already bound to it. Returns ------- @@ -252,6 +269,21 @@ def can_prove( """ return _ffi_api.AnalyzerCanProve(self, expr, strength) + def set_maximum_rewrite_steps(self, maximum: int) -> None: + """Set the maximum allowed number of rewrite-simplify steps. + + When a positive limit is set, the simplifier raises an exception once + it exceeds that number of rewrite steps. This is useful for guarding + against performance regressions in tests. + + Parameters + ---------- + maximum : int + The maximum number of rewrite steps, or a non-positive value to + allow an unlimited number of steps. + """ + _ffi_api.AnalyzerSetMaximumRewriteSteps(self, maximum) + def bind( self, var: tirx.Var, @@ -304,22 +336,30 @@ def _fenter(): return ConstraintScope(_fenter) - def update(self, var: tirx.Var, info: ConstIntBound, override: bool = False) -> None: - """Update infomation about var + def update( + self, var: tirx.Var, info: ConstIntBound | ModularSet | IntSet, override: bool = False + ) -> None: + """Update information about var. Parameters ---------- var : tvm.tirx.Var The variable. - info : tvm.Object - Related information. + info : Union[ConstIntBound, ModularSet, IntSet] + Related information. A ``ConstIntBound`` updates the constant + integer bound, a ``ModularSet`` updates the modular set, and an + ``IntSet`` updates the integer-set domain of ``var``. override : bool Whether allow override. """ if isinstance(info, ConstIntBound): _ffi_api.AnalyzerConstIntBoundUpdate(self, var, info, override) + elif isinstance(info, ModularSet): + _ffi_api.AnalyzerModularSetUpdate(self, var, info, override) + elif isinstance(info, IntSet): + _ffi_api.AnalyzerIntSetUpdate(self, var, info, override) else: raise TypeError(f"Do not know how to handle type {type(info)}") @@ -341,6 +381,31 @@ def can_prove_equal(self, lhs: tirx.PrimExpr, rhs: tirx.PrimExpr) -> bool: """ return _ffi_api.AnalyzerCanProveEqual(self, lhs, rhs) + def try_compare( + self, lhs: tirx.PrimExpr, rhs: tirx.PrimExpr, propagate_inequalities: bool = True + ) -> CompareResult: + """Compare lhs and rhs using previously provided known comparisons. + + Parameters + ---------- + lhs : PrimExpr + The left-hand side of the comparison. + + rhs : PrimExpr + The right-hand side of the comparison. + + propagate_inequalities : bool + If true, attempt to find a sequence of transitive inequalities that + allow lhs and rhs to be compared. + + Returns + ------- + result : CompareResult + The most specific result that can be proven about the comparison. + Returns ``CompareResult.UNKNOWN`` when nothing can be proven. + """ + return CompareResult(_ffi_api.AnalyzerTryCompare(self, lhs, rhs, propagate_inequalities)) + @property def enabled_extensions(self) -> Extension: """Return the currently enabled extensions""" diff --git a/python/tvm/arith/int_set.py b/python/tvm/arith/int_set.py index 9aad8ccfa576..00e2030a4525 100644 --- a/python/tvm/arith/int_set.py +++ b/python/tvm/arith/int_set.py @@ -93,7 +93,7 @@ def __init__(self): self.__init_handle_by_constructor__(_ffi_api.PresburgerSet) -def estimate_region_lower_bound(region, var_dom, predicate): +def estimate_region_lower_bound(region, var_dom, predicate, analyzer=None): """Analyze the region with affine map, given the domain of variables and their predicate Some subregion may be discarded during the lower-bound analysis. @@ -108,15 +108,19 @@ def estimate_region_lower_bound(region, var_dom, predicate): predicate : PrimExpr The predicate for the affine map + analyzer : Optional[tvm.arith.Analyzer] + The analyzer to use. When provided, its accumulated bindings and + constraints are reused; otherwise a fresh analyzer is created. + Returns ---------- region_int_set : Optional[List[IntSet]] None if the detection fails, or an array of IntSets as the result of analysis """ - return _ffi_api.EstimateRegionLowerBound(region, var_dom, predicate) + return _ffi_api.EstimateRegionLowerBound(region, var_dom, predicate, analyzer) -def estimate_region_strict_bound(region, var_dom, predicate): +def estimate_region_strict_bound(region, var_dom, predicate, analyzer=None): """Analyze the region with affine map, given the domain of variables and their predicate The result should be strict, i.e. no region is discarded or relaxed. @@ -131,15 +135,19 @@ def estimate_region_strict_bound(region, var_dom, predicate): predicate : PrimExpr The predicate for the affine map + analyzer : Optional[tvm.arith.Analyzer] + The analyzer to use. When provided, its accumulated bindings and + constraints are reused; otherwise a fresh analyzer is created. + Returns ---------- region_int_set : Optional[List[IntSet]] None if the detection fails, or an array of IntSets as the result of analysis """ - return _ffi_api.EstimateRegionStrictBound(region, var_dom, predicate) + return _ffi_api.EstimateRegionStrictBound(region, var_dom, predicate, analyzer) -def estimate_region_upper_bound(region, var_dom, predicate): +def estimate_region_upper_bound(region, var_dom, predicate, analyzer=None): """Analyze the region with affine map, given the domain of variables and their predicate Relaxation of the region may be used in upper-bound analysis, i.e. some extra region may be added to the result. @@ -155,12 +163,16 @@ def estimate_region_upper_bound(region, var_dom, predicate): predicate : PrimExpr The predicate for the affine map + analyzer : Optional[tvm.arith.Analyzer] + The analyzer to use. When provided, its accumulated bindings and + constraints are reused; otherwise a fresh analyzer is created. + Returns ---------- region_int_set : List[IntSet] an array of IntSets as the result of analysis """ - return _ffi_api.EstimateRegionUpperBound(region, var_dom, predicate) + return _ffi_api.EstimateRegionUpperBound(region, var_dom, predicate, analyzer) def pos_inf(): diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 0dae45c1a55e..0c0a3b310b05 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -129,6 +129,7 @@ def detect_iter_map( predicate=True, check_level=IterMapLevel.Surjective, simplify_trivial_iterators=True, + analyzer=None, ): """Detect if indices can be written as mapped iters from input iters @@ -150,6 +151,10 @@ def detect_iter_map( If true, iterators with extent of 1 will be replaced with a constant value. + analyzer : Optional[tvm.arith.Analyzer] + The analyzer to use. When provided, its accumulated bindings and + constraints are reused; otherwise a fresh analyzer is created. + Returns ------- results : IterMapResult @@ -162,11 +167,11 @@ def detect_iter_map( elif check_level is None: check_level = IterMapLevel.NoCheck return _ffi_api.DetectIterMap( - indices, input_iters, predicate, check_level, simplify_trivial_iterators + indices, input_iters, predicate, check_level, simplify_trivial_iterators, analyzer ) -def normalize_to_iter_sum(index, input_iters): +def normalize_to_iter_sum(index, input_iters, analyzer=None): """Normalize expr to iter sum. The normalized result ensures that @@ -181,6 +186,10 @@ def normalize_to_iter_sum(index, input_iters): input_iters : Map[tvm.tirx.Var, Range] The domain of each input iterators. + analyzer : Optional[tvm.arith.Analyzer] + The analyzer to use. When provided, its accumulated bindings and + constraints are reused; otherwise a fresh analyzer is created. + Returns ------- iter_sum: IterSumExpr @@ -194,7 +203,7 @@ def normalize_to_iter_sum(index, input_iters): This function is useful to decide the stride multiplier and division factor in buffer access patterns. """ - return _ffi_api.NormalizeToIterSum(index, input_iters) + return _ffi_api.NormalizeToIterSum(index, input_iters, analyzer) def iter_map_simplify( @@ -203,6 +212,7 @@ def iter_map_simplify( predicate=True, check_level=IterMapLevel.Surjective, simplify_trivial_iterators=True, + analyzer=None, ): """Simplify the indices using iter map detection. @@ -224,6 +234,10 @@ def iter_map_simplify( If true, iterators with extent of 1 will be replaced with a constant value. + analyzer : Optional[tvm.arith.Analyzer] + The analyzer to use. When provided, its accumulated bindings and + constraints are reused; otherwise a fresh analyzer is created. + Returns ------- results : IterMapResult @@ -236,7 +250,7 @@ def iter_map_simplify( elif check_level is None: check_level = IterMapLevel.NoCheck return _ffi_api.IterMapSimplify( - indices, input_iters, predicate, check_level, simplify_trivial_iterators + indices, input_iters, predicate, check_level, simplify_trivial_iterators, analyzer ) @@ -263,6 +277,7 @@ def subspace_divide( predicate=True, check_level=IterMapLevel.Surjective, simplify_trivial_iterators=True, + analyzer=None, ): """Detect if bindings can be written as ``[a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n]`` @@ -305,6 +320,10 @@ def subspace_divide( If true, iterators with extent of 1 will be replaced with a constant value. + analyzer : Optional[tvm.arith.Analyzer] + The analyzer to use. When provided, its accumulated bindings and + constraints are reused; otherwise a fresh analyzer is created. + Returns ------- results : List[List[PrimExpr]] @@ -319,7 +338,13 @@ def subspace_divide( if isinstance(check_level, str): check_level = IterMapLevel.from_str(check_level) return _ffi_api.SubspaceDivide( - bindings, input_iters, sub_iters, predicate, check_level, simplify_trivial_iterators + bindings, + input_iters, + sub_iters, + predicate, + check_level, + simplify_trivial_iterators, + analyzer, ) diff --git a/python/tvm/tirx/function.py b/python/tvm/tirx/function.py index d0b10ca7d0f7..36b23c2eb5b3 100644 --- a/python/tvm/tirx/function.py +++ b/python/tvm/tirx/function.py @@ -426,7 +426,7 @@ def from_func_with_separators( return IndexMap(initial_indices, final_indices, inverse_index_map), axis_separators - def is_equivalent_to(self, other_map: "IndexMap") -> bool: + def is_equivalent_to(self, other_map: "IndexMap", analyzer=None) -> bool: """Return if the index maps are equivalent. Parameters @@ -435,6 +435,13 @@ def is_equivalent_to(self, other_map: "IndexMap") -> bool: The IndexMap to which the comparison should be made. + analyzer : Optional[tvm.arith.Analyzer] + + The analyzer to use while comparing the mapped indices. When + provided, its accumulated bindings and constraints are reused so + that maps that are only equivalent under those bindings can be + proven equal. + Returns ------- is_equivalent: bool @@ -447,9 +454,10 @@ def is_equivalent_to(self, other_map: "IndexMap") -> bool: if len(self.final_indices) != len(other_map.final_indices): return False - analyzer = tvm.arith.Analyzer() + if analyzer is None: + analyzer = tvm.arith.Analyzer() - mapped_other_final_indices = other_map.map_indices(self.initial_indices) + mapped_other_final_indices = other_map.map_indices(self.initial_indices, analyzer=analyzer) for self_index, other_index in zip(self.final_indices, mapped_other_final_indices): if not analyzer.can_prove_equal(self_index, other_index): return False diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 45f352c63131..cc3c73bb6207 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -268,6 +268,14 @@ TVM_FFI_STATIC_INIT_BLOCK() { }) .def("arith.AnalyzerConstIntBoundIsBound", [](Analyzer analyzer, const Var& var) { return analyzer->const_int_bound.IsBound(var); }) + .def("arith.AnalyzerModularSetUpdate", + [](Analyzer analyzer, const Var& var, const ModularSet& info, bool allow_override) { + analyzer->modular_set.Update(var, info, allow_override); + }) + .def("arith.AnalyzerIntSetUpdate", + [](Analyzer analyzer, const Var& var, const IntSet& info, bool allow_override) { + analyzer->int_set.Update(var, info, allow_override); + }) .def("arith.AnalyzerModularSet", [](Analyzer analyzer, const PrimExpr& expr) { return analyzer->modular_set(expr); }) .def("arith.AnalyzerSimplify", [](Analyzer analyzer, const PrimExpr& expr, @@ -283,8 +291,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { return analyzer->canonical_simplify(expr); }) .def("arith.AnalyzerIntSet", - [](Analyzer analyzer, const PrimExpr& expr, const ffi::Map& dom_map) { - return analyzer->int_set(expr, dom_map); + [](Analyzer analyzer, const PrimExpr& expr, + ffi::Optional> opt_dom_map) { + if (opt_dom_map.has_value()) { + return analyzer->int_set(expr, opt_dom_map.value()); + } + return analyzer->int_set(expr); }) .def_packed("arith.AnalyzerBind", [](ffi::PackedArgs args, ffi::Any* ret) { @@ -302,6 +314,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](Analyzer analyzer, const PrimExpr& expr, int strength) { return analyzer->CanProve(expr, static_cast(strength)); }) + .def("arith.AnalyzerSetMaximumRewriteSteps", + [](Analyzer analyzer, int64_t maximum) { + analyzer->rewrite_simplify.SetMaximumRewriteSteps(maximum); + }) .def("arith.AnalyzerEnterConstraintContext", [](Analyzer analyzer, const PrimExpr& constraint) { // can't use make_shared due to noexcept(false) decl in destructor, @@ -312,6 +328,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { return ffi::Function::FromPacked(fexit); }) .def_method("arith.AnalyzerCanProveEqual", &AnalyzerObj::CanProveEqual) + .def("arith.AnalyzerTryCompare", + [](Analyzer analyzer, const PrimExpr& lhs, const PrimExpr& rhs, + bool propagate_inequalities) { + return static_cast( + analyzer->transitive_comparisons.TryCompare(lhs, rhs, propagate_inequalities)); + }) .def("arith.AnalyzerGetEnabledExtensions", [](Analyzer analyzer) { return static_cast(analyzer->rewrite_simplify.GetEnabledExtensions()); diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index d16a6bc7b58d..8659807cc7ea 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -1270,21 +1270,21 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def_method("arith.IntSetIsNothing", &IntSet::IsNothing) .def_method("arith.IntSetIsEverything", &IntSet::IsEverything) .def("arith.EstimateRegionLowerBound", - [](ffi::Array region, ffi::Map var_dom, - PrimExpr predicate) -> ffi::Optional> { - Analyzer analyzer; + [](ffi::Array region, ffi::Map var_dom, PrimExpr predicate, + ffi::Optional opt_analyzer) -> ffi::Optional> { + Analyzer analyzer = opt_analyzer.has_value() ? opt_analyzer.value() : Analyzer(); return EstimateRegionLowerBound(region, var_dom, predicate, analyzer); }) .def("arith.EstimateRegionStrictBound", - [](ffi::Array region, ffi::Map var_dom, - PrimExpr predicate) -> ffi::Optional> { - Analyzer analyzer; + [](ffi::Array region, ffi::Map var_dom, PrimExpr predicate, + ffi::Optional opt_analyzer) -> ffi::Optional> { + Analyzer analyzer = opt_analyzer.has_value() ? opt_analyzer.value() : Analyzer(); return EstimateRegionStrictBound(region, var_dom, predicate, analyzer); }) .def("arith.EstimateRegionUpperBound", - [](ffi::Array region, ffi::Map var_dom, - PrimExpr predicate) -> ffi::Optional> { - Analyzer analyzer; + [](ffi::Array region, ffi::Map var_dom, PrimExpr predicate, + ffi::Optional opt_analyzer) -> ffi::Optional> { + Analyzer analyzer = opt_analyzer.has_value() ? opt_analyzer.value() : Analyzer(); return EstimateRegionUpperBound(region, var_dom, predicate, analyzer); }) .def("arith.PosInf", []() { return SymbolicLimits::pos_inf_; }) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 8623efa1a64d..1930feb42877 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1520,8 +1520,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def( "arith.DetectIterMap", [](const ffi::Array& indices, const ffi::Map& input_iters, - const PrimExpr& input_pred, int check_level, bool simplify_trivial_iterators) { - arith::Analyzer ana; + const PrimExpr& input_pred, int check_level, bool simplify_trivial_iterators, + ffi::Optional opt_analyzer) { + Analyzer ana = opt_analyzer.has_value() ? opt_analyzer.value() : Analyzer(); return DetectIterMap(indices, input_iters, input_pred, IterMapLevel(check_level), ana, simplify_trivial_iterators); }); @@ -1546,11 +1547,12 @@ IterSumExpr NormalizeToIterSum(PrimExpr index, const ffi::Map& input TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("arith.NormalizeToIterSum", - [](PrimExpr index, const ffi::Map& input_iters) { - arith::Analyzer ana; - return NormalizeToIterSum(index, input_iters, ana); - }); + refl::GlobalDef().def( + "arith.NormalizeToIterSum", [](PrimExpr index, const ffi::Map& input_iters, + ffi::Optional opt_analyzer) { + Analyzer ana = opt_analyzer.has_value() ? opt_analyzer.value() : Analyzer(); + return NormalizeToIterSum(index, input_iters, ana); + }); } PrimExpr IterMapRewriter::VisitExpr_(const VarNode* op) { @@ -2186,8 +2188,9 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def( "arith.IterMapSimplify", [](const ffi::Array& indices, const ffi::Map& input_iters, - const PrimExpr& input_pred, int check_level, bool simplify_trivial_iterators) { - arith::Analyzer ana; + const PrimExpr& input_pred, int check_level, bool simplify_trivial_iterators, + ffi::Optional opt_analyzer) { + Analyzer ana = opt_analyzer.has_value() ? opt_analyzer.value() : Analyzer(); return IterMapSimplify(indices, input_iters, input_pred, IterMapLevel(check_level), ana, simplify_trivial_iterators); }); @@ -2526,8 +2529,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { "arith.SubspaceDivide", [](const ffi::Array& bindings, const ffi::Map& root_iters, const ffi::Array& sub_iters, const PrimExpr& predicate, int check_level, - bool simplify_trivial_iterators) { - arith::Analyzer ana; + bool simplify_trivial_iterators, ffi::Optional opt_analyzer) { + Analyzer ana = opt_analyzer.has_value() ? opt_analyzer.value() : Analyzer(); return SubspaceDivide(bindings, root_iters, sub_iters, predicate, IterMapLevel(check_level), ana, simplify_trivial_iterators); }); diff --git a/src/tirx/ir/index_map.cc b/src/tirx/ir/index_map.cc index 1e27503c082e..b26ccca248d6 100644 --- a/src/tirx/ir/index_map.cc +++ b/src/tirx/ir/index_map.cc @@ -129,13 +129,14 @@ std::pair IndexMapInverseImpl(const IndexMap& self, { TVM_FFI_ICHECK_EQ(output_ranges.size(), output_vars.size()); - arith::Analyzer analyzer; + arith::Analyzer output_var_analyzer; for (size_t i = 0; i < output_vars.size(); ++i) { - analyzer->Bind(output_vars[i], output_ranges[i]); + output_var_analyzer->Bind(output_vars[i], output_ranges[i]); } // Additional simplification steps required to unwrap nested floordiv/floormod padding_predicate = analyzer->Simplify(padding_predicate, 10); + padding_predicate = output_var_analyzer->Simplify(padding_predicate, 10); } return {IndexMap(output_vars, inverse_exprs), padding_predicate}; @@ -224,7 +225,8 @@ ffi::Array IndexMapNode::MapRanges(const ffi::Array& ranges, extent = term_extent; } } - output.push_back(Range::FromMinExtent(index->base, extent.value_or(1))); + extent = analyzer_ptr->Simplify(extent.value_or(1)); + output.push_back(Range::FromMinExtent(index->base, extent.value())); } } else { diff --git a/tests/python/arith/test_arith_analyzer_object.py b/tests/python/arith/test_arith_analyzer_object.py index a88c413b43ea..2b3931dfd97b 100644 --- a/tests/python/arith/test_arith_analyzer_object.py +++ b/tests/python/arith/test_arith_analyzer_object.py @@ -15,9 +15,12 @@ # specific language governing permissions and limitations # under the License. +import pytest + import tvm import tvm.testing from tvm import tirx +from tvm.arith.analyzer import CompareResult, Extension from tvm.runtime import Object @@ -55,5 +58,151 @@ def test_analyzer_object_constraint_scope_and_override_bind(): tvm.ir.assert_structural_equal(analyzer.simplify(y + 1), tirx.const(9, "int64")) +def test_analyzer_object_update_const_int_bound(): + analyzer = tvm.arith.Analyzer() + x = tirx.Var("x", "int64") + + analyzer.update(x, tvm.arith.ConstIntBound(2, 5)) + + bound = analyzer.const_int_bound(x + 1) + assert bound.min_value == 3 + assert bound.max_value == 6 + + +def test_analyzer_object_update_modular_set(): + analyzer = tvm.arith.Analyzer() + x = tirx.Var("x", "int32") + + assert analyzer.modular_set(x).coeff == 1 + analyzer.update(x, tvm.arith.ModularSet(4, 0)) + + result = analyzer.modular_set(x) + assert result.coeff == 4 + assert result.base == 0 + + +def test_analyzer_object_update_int_set(): + analyzer = tvm.arith.Analyzer() + y = tirx.Var("y", "int32") + + analyzer.update(y, tvm.arith.IntervalSet(0, 8)) + + int_set = analyzer.int_set(y) + assert int_set.min_value.value == 0 + assert int_set.max_value.value == 8 + + +def test_analyzer_object_update_rejects_unknown_info(): + analyzer = tvm.arith.Analyzer() + y = tirx.Var("y", "int32") + + with pytest.raises(TypeError): + analyzer.update(y, "not-an-info-object") + + +def test_analyzer_object_can_prove_comparison_predicates(): + analyzer = tvm.arith.Analyzer() + x = tirx.Var("x", "int32") + analyzer.bind(x, tvm.ir.Range(0, 8)) + + assert analyzer.can_prove(x >= 0) + assert not analyzer.can_prove(x >= 1) + assert analyzer.can_prove(x < 8) + assert not analyzer.can_prove(x < 7) + + +def test_analyzer_object_update_const_int_bound_half_space(): + analyzer = tvm.arith.Analyzer() + n = tirx.Var("n", "int32") + + assert not analyzer.can_prove(n >= 0) + analyzer.update(n, tvm.arith.ConstIntBound(0, tvm.arith.ConstIntBound.POS_INF)) + assert analyzer.can_prove(n >= 0) + + +def test_analyzer_object_int_set_from_bound_vars(): + analyzer = tvm.arith.Analyzer() + x = tirx.Var("x", "int32") + analyzer.bind(x, tvm.ir.Range(0, 8)) + + int_set = analyzer.int_set(x + 1) + assert int_set.min_value.value == 1 + assert int_set.max_value.value == 8 + + +def test_analyzer_object_set_maximum_rewrite_steps(): + x = tirx.Var("x", "int32") + y = tirx.Var("y", "int32") + expr = (x + y) * 2 - x * 2 - y * 2 + tirx.max(x, y) - tirx.min(x, y) + + capped = tvm.arith.Analyzer() + capped.set_maximum_rewrite_steps(1) + with pytest.raises(tvm.TVMError): + capped.rewrite_simplify(expr) + + # A generous limit must not interfere with normal simplification. + relaxed = tvm.arith.Analyzer() + relaxed.set_maximum_rewrite_steps(1000) + relaxed.rewrite_simplify(expr) + + +def test_analyzer_object_try_compare_transitive(): + analyzer = tvm.arith.Analyzer() + x = tirx.Var("x", "int32") + y = tirx.Var("y", "int32") + z = tirx.Var("z", "int32") + + assert analyzer.try_compare(x, y) == CompareResult.UNKNOWN + + with analyzer.constraint_scope(x < y): + with analyzer.constraint_scope(y < z): + # Direct known comparison. + assert analyzer.try_compare(x, y) == CompareResult.LT + # Transitive chain x < y < z is found only when propagation is enabled. + assert analyzer.try_compare(x, z) == CompareResult.LT + assert analyzer.try_compare(x, z, propagate_inequalities=False) == CompareResult.UNKNOWN + + +def test_analyzer_object_enabled_extensions_round_trip(): + analyzer = tvm.arith.Analyzer() + + assert analyzer.enabled_extensions == Extension.NoExtensions + + analyzer.enabled_extensions = Extension.ComparisonOfProductAndSum + assert analyzer.enabled_extensions == Extension.ComparisonOfProductAndSum + + analyzer.enabled_extensions = Extension.NoExtensions + assert analyzer.enabled_extensions == Extension.NoExtensions + + +def test_analyzer_object_rewrite_simplify_stats(): + analyzer = tvm.arith.Analyzer() + x = tirx.Var("x", "int32") + + analyzer.reset_rewrite_simplify_stats() + assert analyzer.rewrite_simplify_stats.nodes_visited == 0 + + analyzer.rewrite_simplify(x + 0) + assert analyzer.rewrite_simplify_stats.nodes_visited > 0 + + analyzer.reset_rewrite_simplify_stats() + assert analyzer.rewrite_simplify_stats.nodes_visited == 0 + + +def test_analyzer_object_state_persists_across_ffi_calls(): + analyzer = tvm.arith.Analyzer() + tile = tirx.Var("tile", "int32") + i = tirx.Var("i", "int32") + analyzer.bind(tile, tvm.tirx.const(8, "int32")) + + # The same analyzer object is borrowed by the C++ DetectIterMap entry point; + # its binding makes the otherwise-undetectable floormod recognizable. + result = tvm.arith.detect_iter_map([i % tile], {i: tvm.ir.Range(0, 32)}, analyzer=analyzer) + assert len(result.indices) == 1 + + # The binding still lives in the same stateful object after the FFI call. + tvm.ir.assert_structural_equal(analyzer.simplify(tile), tvm.tirx.const(8, "int32")) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/arith/test_arith_intset.py b/tests/python/arith/test_arith_intset.py index a34c528e69ef..5c8fda2700e6 100644 --- a/tests/python/arith/test_arith_intset.py +++ b/tests/python/arith/test_arith_intset.py @@ -424,5 +424,33 @@ def test_relax_cyclic_variable_dependency(): assert res is not None +def test_estimate_region_accepts_external_analyzer(): + i = tvm.tirx.Var("i", "int32") + tile = tvm.tirx.Var("tile", "int32") + region = [tvm.ir.Range.from_min_extent(i % tile, 1)] + dom = {i: tvm.ir.Range(0, 16)} + + # Without knowing `tile`, the affine detection fails for exact bounds. + assert tvm.arith.estimate_region_lower_bound(region, dom, True) is None + assert tvm.arith.estimate_region_strict_bound(region, dom, True) is None + upper_without_analyzer = tvm.arith.estimate_region_upper_bound(region, dom, True) + + analyzer = tvm.arith.Analyzer() + analyzer.bind(tile, tvm.tirx.const(4, "int32")) + # The external binding lets the affine detection succeed. + for estimate_region in [ + tvm.arith.estimate_region_lower_bound, + tvm.arith.estimate_region_strict_bound, + tvm.arith.estimate_region_upper_bound, + ]: + result = estimate_region(region, dom, True, analyzer=analyzer) + assert result is not None + assert analyzer.can_prove_equal(result[0].min_value, 0) + assert analyzer.can_prove_equal(result[0].max_value, 3) + + # The upper-bound fallback without analyzer is safe but much wider. + assert not analyzer.can_prove_equal(upper_without_analyzer[0].min_value, 0) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/arith/test_arith_iter_affine_map.py b/tests/python/arith/test_arith_iter_affine_map.py index 9cb4f790db08..fdbb65a0bd71 100644 --- a/tests/python/arith/test_arith_iter_affine_map.py +++ b/tests/python/arith/test_arith_iter_affine_map.py @@ -732,6 +732,25 @@ def test_subspace_division(): assert len(res) == 0 +def test_subspace_divide_accepts_external_analyzer(): + i = tvm.tirx.Var("i", "int32") + j = tvm.tirx.Var("j", "int32") + tile = tvm.tirx.Var("tile", "int32") + root_iters = {i: tvm.ir.Range(0, 4), j: tvm.ir.Range(0, tile)} + bindings = [j * tile + i] + + assert len(tvm.arith.subspace_divide(bindings, root_iters, [i])) == 0 + + analyzer = tvm.arith.Analyzer() + analyzer.bind(tile, T.int32(4)) + res = tvm.arith.subspace_divide(bindings, root_iters, [i], analyzer=analyzer) + res = convert_division(res) + + assert len(res) == 2 + tvm.ir.assert_structural_equal(res[0][0], j) + tvm.ir.assert_structural_equal(res[0][1], i) + + def test_subspace_divide_trivial_iters(): x = tvm.tirx.Var("x", "int32") y = tvm.tirx.Var("y", "int32") @@ -1349,6 +1368,20 @@ def test_normalize_to_iter_sum(): ) +def test_normalize_to_iter_sum_accepts_external_analyzer(): + i = tvm.tirx.Var("i", "int32") + tile = tvm.tirx.Var("tile", "int32") + input_iters = {i: tvm.ir.Range(0, 16)} + + analyzer = tvm.arith.Analyzer() + analyzer.bind(tile, T.int32(4)) + res = tvm.arith.normalize_to_iter_sum(i // tile, input_iters, analyzer=analyzer) + + assert len(res.args) == 1 + tvm.testing.assert_prim_expr_equal(res.args[0].lower_factor, tile) + tvm.testing.assert_prim_expr_equal(res.args[0].extent, T.int32(4)) + + def test_detect_iter_map_with_bufferload_recursion(): n = tvm.tirx.Var("n", "int32") m = tvm.tirx.Var("m", "int32") @@ -1369,5 +1402,30 @@ def test_detect_iter_map_with_bufferload_recursion(): assert len(result.indices) == 0 +def test_detect_iter_map_accepts_external_analyzer(): + i = tvm.tirx.Var("i", "int32") + tile = tvm.tirx.Var("tile", "int32") + iter_vars = {i: tvm.ir.Range(0, 16)} + + # Without knowing `tile`, the floormod cannot be recognized as an iterator. + assert len(tvm.arith.detect_iter_map([i % tile], iter_vars).indices) == 0 + + analyzer = tvm.arith.Analyzer() + analyzer.bind(tile, T.int32(4)) + # The external analyzer supplies `tile == 4`, allowing detection to succeed. + assert len(tvm.arith.detect_iter_map([i % tile], iter_vars, analyzer=analyzer).indices) == 1 + + +def test_iter_map_simplify_accepts_external_analyzer(): + i = tvm.tirx.Var("i", "int32") + tile = tvm.tirx.Var("tile", "int32") + iter_vars = {i: tvm.ir.Range(0, 32)} + + analyzer = tvm.arith.Analyzer() + analyzer.bind(tile, T.int32(8)) + simplified = tvm.arith.iter_map_simplify([i % tile], iter_vars, analyzer=analyzer) + tvm.ir.assert_structural_equal(simplified, [i % 8]) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tirx-base/test_tir_index_map.py b/tests/python/tirx-base/test_tir_index_map.py index a439728b5e84..539ff3480430 100644 --- a/tests/python/tirx-base/test_tir_index_map.py +++ b/tests/python/tirx-base/test_tir_index_map.py @@ -65,12 +65,23 @@ def test_map_shape_accepts_external_analyzer(): index_map = IndexMap.from_func(lambda i: [i // tile, i % tile], index_dtype="int32") analyzer = tvm.arith.Analyzer() - unsimplified = index_map.map_shape([T.int32(32)])[0] analyzer.bind(tile, T.int32(16)) - simplified = index_map.map_shape([T.int32(32)], analyzer=analyzer) + mapped_shape = index_map.map_shape([T.int32(32)], analyzer=analyzer) - assert not tvm_ffi.structural_equal(unsimplified, T.int32(2)) - assert_structural_equal(simplified, [T.int32(2), tile]) + assert_structural_equal(mapped_shape, [T.int32(2), T.int32(16)]) + + +def test_is_equivalent_to_accepts_external_analyzer(): + tile = tvm.tirx.Var("tile", "int32") + concrete = IndexMap.from_func(lambda i: [i // 4, i % 4], index_dtype="int32") + symbolic = IndexMap.from_func(lambda i: [i // tile, i % tile], index_dtype="int32") + + # Without binding `tile`, the symbolic map cannot be proven equivalent. + assert not concrete.is_equivalent_to(symbolic) + + analyzer = tvm.arith.Analyzer() + analyzer.bind(tile, T.int32(4)) + assert concrete.is_equivalent_to(symbolic, analyzer=analyzer) def test_shape_mapping(): @@ -260,6 +271,18 @@ def test_non_surjective_inverse_accepts_external_analyzer(): assert_structural_equal(analyzer.simplify(padding_at_first_element), T.bool(False)) +def test_non_surjective_inverse_does_not_bind_output_vars_to_external_analyzer(): + tile = tvm.tirx.Var("tile", "int32") + index_map = IndexMap.from_func(lambda i: [i // tile, i % tile], index_dtype="int32") + analyzer = tvm.arith.Analyzer() + + analyzer.bind(tile, T.int32(16)) + inverse, _ = index_map.non_surjective_inverse([T.int32(31)], analyzer=analyzer) + + analyzer.bind(inverse.initial_indices[0], T.int32(0)) + analyzer.bind(inverse.initial_indices[1], T.int32(1)) + + def test_index_map_inverse_no_iter(): def input_example(i0, i1, i2, i3): j0 = floordiv(i3, 32)