Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 105 additions & 60 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
#define TVM_ARITH_ANALYZER_H_

#include <tvm/arith/int_set.h>
#include <tvm/ffi/cast.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/with_context.h>

#include <limits>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

namespace tvm {
Expand All @@ -48,8 +50,10 @@ namespace arith {
// another analyzer.
//-------------------------------------------------------

// Forward declare Analyzer
// Forward declare the analyzer object and its reference handle.
class AnalyzerObj;
class Analyzer;
class ConstraintContext;

using tirx::Var;

Expand Down Expand Up @@ -172,9 +176,9 @@ class ConstIntBoundAnalyzer {
TVM_DLL bool IsBound(const Var& var) const;

private:
friend class Analyzer;
friend class AnalyzerObj;
friend class ConstraintContext;
explicit ConstIntBoundAnalyzer(Analyzer* parent);
explicit ConstIntBoundAnalyzer(AnalyzerObj* parent);
TVM_DLL ~ConstIntBoundAnalyzer();
/*!
* \brief Update the internal state to enter constraint.
Expand Down Expand Up @@ -251,9 +255,9 @@ class ModularSetAnalyzer {
TVM_DLL void Update(const Var& var, const ModularSet& info, bool allow_override = false);

private:
friend class Analyzer;
friend class AnalyzerObj;
friend class ConstraintContext;
explicit ModularSetAnalyzer(Analyzer* parent);
explicit ModularSetAnalyzer(AnalyzerObj* parent);
TVM_DLL ~ModularSetAnalyzer();
/*!
* \brief Update the internal state to enter constraint.
Expand Down Expand Up @@ -397,16 +401,17 @@ class RewriteSimplifier {
* Note: To maintain accurate usage counters, `Analyzer` instances
* should be re-used wherever possible. For example, TIR
* transformations should declare a single `Analyzer` that is used
* throughout the pass, and utility functions should receive an
* `Analyzer*` from their calling scope.
* throughout the pass. Internal helper functions that only borrow
* the analyzer temporarily may receive the underlying `AnalyzerObj*`
* from their calling scope.
*/
TVM_DLL void SetMaximumRewriteSteps(int64_t maximum);

private:
friend class Analyzer;
friend class AnalyzerObj;
friend class ConstraintContext;
friend class CanonicalSimplifier;
explicit RewriteSimplifier(Analyzer* parent);
explicit RewriteSimplifier(AnalyzerObj* parent);
TVM_DLL ~RewriteSimplifier();
class Impl;
/*! \brief Internal impl */
Expand Down Expand Up @@ -435,9 +440,9 @@ class CanonicalSimplifier {
TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);

private:
friend class Analyzer;
friend class AnalyzerObj;
friend class ConstraintContext;
explicit CanonicalSimplifier(Analyzer* parent);
explicit CanonicalSimplifier(AnalyzerObj* parent);
TVM_DLL ~CanonicalSimplifier();
class Impl;
/*! \brief Internal impl */
Expand Down Expand Up @@ -520,7 +525,7 @@ class TransitiveComparisonAnalyzer {
TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint);

private:
friend class Analyzer;
friend class AnalyzerObj;
friend class ConstraintContext;
TransitiveComparisonAnalyzer();
TVM_DLL ~TransitiveComparisonAnalyzer();
Expand All @@ -529,45 +534,6 @@ class TransitiveComparisonAnalyzer {
std::unique_ptr<Impl> impl_;
};

/*!
* \brief Constraint context.
*
* \code
*
* Var("x");
* arith::Analyzer analyzer;
* {
* With<arith::ConstraintContext> scope(&analyzer, x % 3 == 0);
* TVM_FFI_ICHECK_EQ(analyzer.modular_set(x)->coeff, 3);
* }
* // constraint no longer in effect.
* TVM_FFI_ICHECK_NE(analyzer.modular_set(x)->coeff, 3);
*
* \endcode
*/
class ConstraintContext {
private:
// declare friend to enable with.
friend class With<ConstraintContext>;
/*!
* \brief Construct a constraint context.
* \param analyzer The analyzer.
* \param constraint The constraint to be applied.
*/
ConstraintContext(Analyzer* analyzer, PrimExpr constraint)
: analyzer_(analyzer), constraint_(constraint) {}
// enter the scope.
void EnterWithScope();
// exit the scope.
void ExitWithScope();
/*! \brief The analyzer */
Analyzer* analyzer_;
/*! \brief The constraint */
PrimExpr constraint_;
/*! \brief functions to be called in recovery */
std::vector<std::function<void()>> recovery_functions_;
};

/*!
* \brief Integer set analyzer.
*/
Expand Down Expand Up @@ -614,8 +580,8 @@ class IntSetAnalyzer {
std::function<void()> EnterConstraint(const PrimExpr& constraint);

private:
friend class Analyzer;
explicit IntSetAnalyzer(Analyzer* parent);
friend class AnalyzerObj;
explicit IntSetAnalyzer(AnalyzerObj* parent);
TVM_DLL ~IntSetAnalyzer();
class Impl;
/*! \brief Internal impl */
Expand All @@ -632,13 +598,8 @@ class IntSetAnalyzer {
* If the analyzer uses memoization, we need to clear the internal
* cache when information about a Var has been overridden.
*/
class TVM_DLL Analyzer {
class TVM_DLL AnalyzerObj : public ffi::Object {
public:
/*
* Disable copy constructor.
*/
Analyzer(const Analyzer&) = delete;
Analyzer& operator=(const Analyzer&) = delete;
/*! \brief sub-analyzer: const integer bound */
ConstIntBoundAnalyzer const_int_bound;
/*! \brief sub-analyzer: modular set */
Expand All @@ -652,7 +613,7 @@ class TVM_DLL Analyzer {
/*! \brief sub-analyzer transitive comparisons */
TransitiveComparisonAnalyzer transitive_comparisons;
/*! \brief constructor */
Analyzer();
AnalyzerObj();
/*!
* \brief Mark the value as non-negative value globally in analyzer.
*
Expand Down Expand Up @@ -785,6 +746,90 @@ class TVM_DLL Analyzer {
* \note Analyzer will call into sub-analyzers to get the result.
*/
PrimExpr Simplify(const PrimExpr& expr, int steps = 2);

/*!
* \brief Analyzer methods update facts, constraints, caches, and stats.
*
* Marking the object mutable makes the `Analyzer` ObjectRef expose a
* non-const `operator->`, so APIs can take `const Analyzer&` while still
* allowing calls such as `analyzer->Bind(...)`.
* `const Analyzer&` keeps the handle itself from being rebound; it does
* not make the underlying AnalyzerObj immutable.
*/
static constexpr bool _type_mutable = true;
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("arith.Analyzer", AnalyzerObj, ffi::Object);
};

/*!
* \brief Managed reference to AnalyzerObj.
*
* Analyzer is a lightweight, reference-counted handle around a heap-allocated
* AnalyzerObj. Because it is now a first-class FFI object, an Analyzer can be
* passed across the tvm-ffi boundary (e.g. handed from Python into a C++ pass)
* and shared, so that accumulated bindings/constraints persist across calls.
* Copying an Analyzer copies the handle, and both handles share the same
* mutable AnalyzerObj state.
* This is not a deep copy of analyzer facts or caches.
*
* \sa AnalyzerObj
*/
class Analyzer : public ffi::ObjectRef {
public:
/*! \brief Default-construct a fresh analyzer (allocates an AnalyzerObj). */
Analyzer() : Analyzer(ffi::make_object<AnalyzerObj>()) {}
explicit Analyzer(ffi::ObjectPtr<AnalyzerObj> n) : ffi::ObjectRef(std::move(n)) {
TVM_FFI_ICHECK(this->get() != nullptr);
}
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Analyzer, ffi::ObjectRef, AnalyzerObj);
};

/*!
* \brief Constraint context.
*
* \code
*
* Var x("x");
* arith::Analyzer analyzer;
* {
* With<arith::ConstraintContext> scope(analyzer, tvm::floormod(x, 3) == 0);
* TVM_FFI_ICHECK_EQ(analyzer->modular_set(x)->coeff, 3);
* }
* // constraint no longer in effect.
* TVM_FFI_ICHECK_NE(analyzer->modular_set(x)->coeff, 3);
*
* \endcode
*/
class ConstraintContext {
private:
// declare friend to enable with.
friend class With<ConstraintContext>;
/*!
* \brief Construct a constraint context.
* \param analyzer The analyzer whose context is updated. The context
* keeps a reference to the analyzer while the scope is active.
* \param constraint The constraint to be applied.
*/
ConstraintContext(const Analyzer& analyzer, PrimExpr constraint)
: analyzer_(analyzer), constraint_(constraint) {}
/*!
* \brief Construct a constraint context from a borrowed analyzer object.
* \param analyzer The borrowed analyzer object.
* \param constraint The constraint to be applied.
*
* This overload is for internal callers that already operate on AnalyzerObj*.
*/
ConstraintContext(AnalyzerObj* analyzer, PrimExpr constraint)
: ConstraintContext(ffi::GetRef<Analyzer>(analyzer), std::move(constraint)) {}
// enter the scope.
void EnterWithScope();
// exit the scope.
void ExitWithScope();
/*! \brief Analyzer kept alive while the context is active. */
Analyzer analyzer_;
/*! \brief The constraint */
PrimExpr constraint_;
/*! \brief functions to be called in recovery */
std::vector<std::function<void()>> recovery_functions_;
};

} // namespace arith
Expand Down
9 changes: 5 additions & 4 deletions include/tvm/arith/int_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ using tirx::IterVar;
using tirx::Var;
using tirx::VarNode;

class AnalyzerObj;
class Analyzer;

//-----------------------------------------------
Expand Down Expand Up @@ -96,7 +97,7 @@ class IntSet : public ffi::ObjectRef {
* \param ana Analyzer used in the proof.
* \return Whether we can prove it is a single point
*/
bool CanProveSinglePoint(Analyzer* ana) const;
bool CanProveSinglePoint(const Analyzer& ana) const;
// TODO(tvm-team): update all CanProve to explicitly take
// analyzer to encourage more analyzer reuse
/*! \return Whether the set is proved to be bigger than 0 */
Expand Down Expand Up @@ -302,7 +303,7 @@ ffi::Map<Var, arith::IntSet> AsIntSet(const ffi::Map<Var, Range>& var_dom);
*/
TVM_DLL ffi::Optional<ffi::Array<IntSet>> EstimateRegionStrictBound(
const ffi::Array<Range>& region, const ffi::Map<Var, Range>& var_dom, const PrimExpr& predicate,
arith::Analyzer* analyzer);
const arith::Analyzer& analyzer);

/*!
* \brief Analyze the region with affine map, given the domain of variables and their predicate.
Expand All @@ -316,7 +317,7 @@ TVM_DLL ffi::Optional<ffi::Array<IntSet>> EstimateRegionStrictBound(
*/
TVM_DLL ffi::Optional<ffi::Array<IntSet>> EstimateRegionLowerBound(
const ffi::Array<Range>& region, const ffi::Map<Var, Range>& var_dom, const PrimExpr& predicate,
arith::Analyzer* analyzer);
const arith::Analyzer& analyzer);

/*!
* \brief Analyze the region with affine map, given the domain of variables and their predicate
Expand All @@ -331,7 +332,7 @@ TVM_DLL ffi::Optional<ffi::Array<IntSet>> EstimateRegionLowerBound(
TVM_DLL ffi::Array<IntSet> EstimateRegionUpperBound(const ffi::Array<Range>& region,
const ffi::Map<Var, Range>& var_dom,
const PrimExpr& predicate,
arith::Analyzer* analyzer);
const arith::Analyzer& analyzer);

} // namespace arith
} // namespace tvm
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ class IterMapResult : public ffi::ObjectRef {
*/
IterMapResult DetectIterMap(const ffi::Array<PrimExpr>& indices,
const ffi::Map<Var, Range>& input_iters, const PrimExpr& predicate,
IterMapLevel check_level, arith::Analyzer* analyzer,
IterMapLevel check_level, const arith::Analyzer& analyzer,
bool simplify_trivial_iterators = true);

/*!
Expand All @@ -323,7 +323,7 @@ IterMapResult DetectIterMap(const ffi::Array<PrimExpr>& indices,
ffi::Array<PrimExpr> IterMapSimplify(const ffi::Array<PrimExpr>& indices,
const ffi::Map<Var, Range>& input_iters,
const PrimExpr& input_pred, IterMapLevel check_level,
arith::Analyzer* analyzer,
const arith::Analyzer& analyzer,
bool simplify_trivial_iterators = true);

/*!
Expand Down Expand Up @@ -380,7 +380,7 @@ ffi::Array<ffi::Array<IterMark>> SubspaceDivide(const ffi::Array<PrimExpr>& bind
const ffi::Map<Var, Range>& input_iters,
const ffi::Array<Var>& sub_iters,
const PrimExpr& predicate, IterMapLevel check_level,
arith::Analyzer* analyzer,
const arith::Analyzer& analyzer,
bool simplify_trivial_iterators = true);

/*!
Expand All @@ -407,7 +407,7 @@ PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr);
* \note This function is useful to detect iterator stride patterns.
*/
IterSumExpr NormalizeToIterSum(PrimExpr index, const ffi::Map<Var, Range>& input_iters,
arith::Analyzer* analyzer);
const arith::Analyzer& analyzer);

} // namespace arith
} // namespace tvm
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/scope_stack.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ namespace tvm {
*
* // In VisitStmt_(ForNode):
* return constraints.WithNewScope([&]() -> Stmt {
* constraints.Current().Emplace(&analyzer, condition);
* constraints.Current().Emplace(analyzer, condition);
* return StmtExprMutator::VisitStmt_(op);
* });
* \endcode
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/ir/with_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ class With {
*
* \code
* WithGroup<ConstraintContext> group;
* group.Emplace(&analyzer, cond1); // constructs and enters
* group.Emplace(&analyzer, cond2); // constructs and enters
* group.Emplace(analyzer, cond1); // constructs and enters
* group.Emplace(analyzer, cond2); // constructs and enters
* // destructor: exits cond2, then cond1
* \endcode
*
Expand Down
Loading
Loading