From 6fca18cab6b3e44145b8897f9ab91021f2d461ce Mon Sep 17 00:00:00 2001 From: Ubospica Date: Wed, 3 Jun 2026 16:48:20 -0400 Subject: [PATCH] [ARITH] Add optional Z3-backed proving to Analyzer Add an optional Z3 SMT solver backend to tvm::arith::Analyzer for stronger integer arithmetic proving. The integration is guarded by a new USE_Z3 CMake option (default OFF). When enabled, Analyzer::CanProve runs the existing analysis path first and only falls back to Z3 when the existing analyzers cannot prove the predicate. When disabled, a stub implementation keeps the C++ and Python APIs available without Z3. --- CMakeLists.txt | 2 + cmake/modules/contrib/Z3.cmake | 76 +++ include/tvm/arith/analyzer.h | 109 +++- python/tvm/arith/analyzer.py | 44 ++ src/arith/analyzer.cc | 48 +- src/arith/rewrite_simplify.cc | 10 +- src/arith/rewrite_simplify.h | 2 +- src/target/z3/z3_prover_off.cc | 38 ++ src/target/z3/z3_prover_on.cc | 788 ++++++++++++++++++++++++++++ tests/python/arith/test_arith_z3.py | 92 ++++ 10 files changed, 1199 insertions(+), 10 deletions(-) create mode 100644 cmake/modules/contrib/Z3.cmake create mode 100644 src/target/z3/z3_prover_off.cc create mode 100644 src/target/z3/z3_prover_on.cc create mode 100644 tests/python/arith/test_arith_z3.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 9591352e4d45..565b72df10ec 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -77,6 +77,7 @@ tvm_option(COMPILER_RT_PATH "Path to COMPILER-RT" "3rdparty/compiler-rt") tvm_option(USE_BYODT_POSIT "Build with BYODT software emulated posit custom datatype" OFF) tvm_option(USE_BLAS "The blas library to be linked" none) tvm_option(USE_AMX "Enable Intel AMX" OFF) +tvm_option(USE_Z3 "Build with Z3 SMT solver support" OFF) tvm_option(USE_MKL "MKL root path when use MKL blas" OFF) tvm_option(USE_DNNL "Enable DNNL codegen" OFF) tvm_option(USE_CUDNN "Build with cuDNN" OFF) @@ -447,6 +448,7 @@ include(cmake/modules/contrib/CUTLASS.cmake) include(cmake/modules/contrib/Random.cmake) include(cmake/modules/contrib/Posit.cmake) include(cmake/modules/contrib/Sort.cmake) +include(cmake/modules/contrib/Z3.cmake) include(cmake/modules/contrib/CoreML.cmake) include(cmake/modules/contrib/TensorRT.cmake) include(cmake/modules/contrib/NNAPI.cmake) diff --git a/cmake/modules/contrib/Z3.cmake b/cmake/modules/contrib/Z3.cmake new file mode 100644 index 000000000000..eef62e4cfcd8 --- /dev/null +++ b/cmake/modules/contrib/Z3.cmake @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if(NOT USE_Z3) + list(APPEND COMPILER_SRCS src/target/z3/z3_prover_off.cc) + return() +endif() + +find_package(Z3 QUIET) +set(Z3_PYTHON_RESULT 1) + +if(NOT Z3_FOUND) + find_package(Python3 COMPONENTS Interpreter QUIET) + if(Python3_EXECUTABLE) + execute_process( + COMMAND "${Python3_EXECUTABLE}" -c "import z3; print(z3.__path__[0])" + OUTPUT_VARIABLE Z3_PYTHON_PACKAGE_DIR + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE Z3_PYTHON_RESULT + ) + endif() + + if(Z3_PYTHON_RESULT EQUAL 0 AND NOT Z3_PYTHON_PACKAGE_DIR STREQUAL "") + find_path(Z3_INCLUDE_DIR NO_DEFAULT_PATH NAMES z3++.h PATHS "${Z3_PYTHON_PACKAGE_DIR}/include") + find_library( + Z3_LIBRARY + NO_DEFAULT_PATH + NAMES z3 libz3 + PATHS "${Z3_PYTHON_PACKAGE_DIR}/bin" "${Z3_PYTHON_PACKAGE_DIR}/lib" + "${Z3_PYTHON_PACKAGE_DIR}/lib64" + ) + endif() +endif() + +if(TARGET z3::libz3 OR TARGET Z3::libz3) + if(TARGET z3::libz3) + set(Z3_TARGET z3::libz3) + else() + set(Z3_TARGET Z3::libz3) + endif() + get_target_property(Z3_TARGET_INCLUDE_DIRS ${Z3_TARGET} INTERFACE_INCLUDE_DIRECTORIES) + if(Z3_TARGET_INCLUDE_DIRS) + include_directories(SYSTEM ${Z3_TARGET_INCLUDE_DIRS}) + endif() + list(APPEND TVM_LINKER_LIBS ${Z3_TARGET}) +elseif(Z3_FOUND OR (Z3_INCLUDE_DIR AND Z3_LIBRARY)) + if(NOT Z3_INCLUDE_DIR AND Z3_CXX_INCLUDE_DIRS) + set(Z3_INCLUDE_DIR ${Z3_CXX_INCLUDE_DIRS}) + endif() + if(NOT Z3_LIBRARY AND Z3_LIBRARIES) + set(Z3_LIBRARY ${Z3_LIBRARIES}) + endif() + if(NOT Z3_INCLUDE_DIR OR NOT Z3_LIBRARY) + message(FATAL_ERROR "USE_Z3 is ON, but Z3 include directory or library was not found.") + endif() + include_directories(SYSTEM ${Z3_INCLUDE_DIR}) + list(APPEND TVM_LINKER_LIBS ${Z3_LIBRARY}) +else() + message(FATAL_ERROR "USE_Z3 is ON, but Z3 was not found. Install Z3 or PyPI z3-solver.") +endif() + +list(APPEND COMPILER_SRCS src/target/z3/z3_prover_on.cc) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 3d9e8ebbf93f..b221b506ecca 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -295,7 +296,7 @@ class RewriteSimplifier { * * \return an exit function that must be called to cleanup the constraint can be nullptr. */ - TVM_DLL std::function EnterConstraint(const PrimExpr& constraint); + TVM_DLL std::function EnterConstraint(const PrimExpr& constraint, bool is_assume = false); /*! \brief Flags to enable more computationally-intensive simplifications * @@ -554,8 +555,8 @@ class ConstraintContext { * \param analyzer The analyzer. * \param constraint The constraint to be applied. */ - ConstraintContext(Analyzer* analyzer, PrimExpr constraint) - : analyzer_(analyzer), constraint_(constraint) {} + ConstraintContext(Analyzer* analyzer, PrimExpr constraint, bool is_assume = false) + : analyzer_(analyzer), constraint_(constraint), is_assume_(is_assume) {} // enter the scope. void EnterWithScope(); // exit the scope. @@ -566,6 +567,7 @@ class ConstraintContext { PrimExpr constraint_; /*! \brief functions to be called in recovery */ std::vector> recovery_functions_; + bool is_assume_; }; /*! @@ -622,6 +624,103 @@ class IntSetAnalyzer { Impl* impl_; }; +class Z3Prover { + public: + /*! + * \brief Update binding of var to a new expression. + * + * \param var The variable of interest. + * \param new_range The range of allowed values for this var. + * \param allow_override whether we allow override of existing information. + */ + TVM_DLL void Bind(const Var& var, const Range& new_range, bool allow_override = false); + + /*! + * \brief Update binding of var to a new expression. + * + * \param var The variable of interest. + * \param expr The bound expression. + * \param allow_override whether we allow override of existing information. + */ + TVM_DLL void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false); + + /*! + * \brief Whether can we prove expr is always true. + * + * \param expr The expression. + * \return Whether we can prove it. + */ + TVM_DLL bool CanProve(const PrimExpr& expr); + + /*! + * \brief Update the internal state to enter constraint. + * + * \param constraint A constraint expression. + * \param is_assume Whether the constraint comes from an assumption. + * \return an exit function that must be called to cleanup the constraint can be nullptr. + */ + std::function EnterConstraint(const PrimExpr& constraint, bool is_assume = false); + + /*! + * \brief Get the SMTLIB2 representation of the current context. + * + * \param expr The optional expression to check. + * \return The SMTLIB2 string. + */ + ffi::String GetSMTLIB2(const ffi::Optional expr); + + /*! + * \brief Get statistics about Z3 prover. + * + * \return The statistics string. + */ + ffi::String GetStats(); + + /*! + * \brief Set timeout in milliseconds for Z3 prover. + * + * \param timeout_ms The timeout in milliseconds. + */ + void SetTimeoutMs(unsigned timeout_ms); + + /*! + * \brief Set resource limitation for Z3 prover. + * + * \param rlimit the resource limitation. + */ + void SetRLimit(unsigned rlimit); + + /*! + * \brief Get the Z3 model for the given expression if satisfiable. + * + * \param expr The expression to get the model for. + * \return The model as a string. + */ + ffi::String GetModel(const PrimExpr& expr); + + /*! + * \brief Count the number of integer values that satisfy the current constraints. + * + * This method uses Z3's model enumeration to count how many distinct values of + * the given variable satisfy all current constraints. + * + * \param var The variable to count satisfying values for. + * \param max_count Maximum number of solutions to enumerate. + * \param min_consecutive Minimum consecutive count requirement. + * \return The number of distinct values that satisfy the constraints, or a negative error code. + */ + TVM_DLL int64_t CountSatisfyingValues(const Var& var, int64_t max_count = 2048, + int64_t min_consecutive = 1); + + private: + friend class Analyzer; + explicit Z3Prover(Analyzer* parent); + TVM_DLL ~Z3Prover(); + void CopyFrom(const Z3Prover& other); + class Impl; + Impl* impl_; +}; + /*! * \brief Analyzer that contains bunch of sub-analyzers. * @@ -651,6 +750,8 @@ class TVM_DLL Analyzer { IntSetAnalyzer int_set; /*! \brief sub-analyzer transitive comparisons */ TransitiveComparisonAnalyzer transitive_comparisons; + /*! \brief sub-analyzer using Z3 */ + Z3Prover z3_prover; /*! \brief constructor */ Analyzer(); /*! @@ -785,6 +886,8 @@ class TVM_DLL Analyzer { * \note Analyzer will call into sub-analyzers to get the result. */ PrimExpr Simplify(const PrimExpr& expr, int steps = 2); + + std::function EnterConstraint(const PrimExpr& constraint, bool is_assume = false); }; } // namespace arith diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index ea70c4de3d0f..67275fbc1ee8 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -123,9 +123,53 @@ def __init__(self): self._enter_constraint_context = _mod("enter_constraint_context") self._can_prove_equal = _mod("can_prove_equal") self._can_prove = _mod("can_prove") + self._get_smtlib2 = _mod("get_smtlib2") + self._set_z3_timeout_ms = _mod("set_z3_timeout_ms") + self._set_z3_rlimit = _mod("set_z3_rlimit") + self._get_z3_stats = _mod("get_z3_stats") self._get_enabled_extensions = _mod("get_enabled_extensions") self._set_enabled_extensions = _mod("set_enabled_extensions") + def get_smtlib2(self, expr: tirx.PrimExpr = None) -> str: + """Get the current Z3 problem in SMT-LIB2 format. + + Parameters + ---------- + expr : Optional[PrimExpr] + The expression to prove. If provided, its negation is added to the problem. + """ + return self._get_smtlib2(expr) + + def set_z3_timeout_ms(self, timeout_ms: int) -> None: + """Set Z3 timeout in milliseconds. + + Parameters + ---------- + timeout_ms : int + The timeout in milliseconds. + """ + self._set_z3_timeout_ms(timeout_ms) + + def set_z3_rlimit(self, rlimit: int) -> None: + """Set Z3 resource limit. + + Parameters + ---------- + rlimit : int + The resource limit. + """ + self._set_z3_rlimit(rlimit) + + def get_z3_stats(self) -> str: + """Get Z3 solver statistics. + + Returns + ------- + stats : str + The Z3 statistics. + """ + return self._get_z3_stats() + def const_int_bound(self, expr: tirx.PrimExpr) -> ConstIntBound: """Find constant integer bound for expr. diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 38c699692e7f..17f989dbf148 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -39,7 +39,8 @@ Analyzer::Analyzer() modular_set(this), rewrite_simplify(this), canonical_simplify(this), - int_set(this) {} + int_set(this), + z3_prover(this) {} void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { PrimExpr new_expr = expr; @@ -52,6 +53,7 @@ void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { this->canonical_simplify.Update(var, new_expr, allow_override); this->int_set.Update(var, this->int_set(new_expr), allow_override); this->transitive_comparisons.Bind(var, expr, allow_override); + this->z3_prover.Bind(var, expr, allow_override); } void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) { @@ -62,6 +64,7 @@ void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) { this->const_int_bound.Bind(var, range, allow_override); this->int_set.Bind(var, range, allow_override); this->transitive_comparisons.Bind(var, range, allow_override); + this->z3_prover.Bind(var, range, allow_override); } // skip modular_set // skip rewrite simplify @@ -128,9 +131,11 @@ void ConstraintContext::EnterWithScope() { // entering the scope. recovery_functions_.push_back(analyzer_->const_int_bound.EnterConstraint(constraint_)); recovery_functions_.push_back(analyzer_->modular_set.EnterConstraint(constraint_)); - recovery_functions_.push_back(analyzer_->rewrite_simplify.EnterConstraint(constraint_)); + recovery_functions_.push_back( + analyzer_->rewrite_simplify.EnterConstraint(constraint_, is_assume_)); recovery_functions_.push_back(analyzer_->int_set.EnterConstraint(constraint_)); recovery_functions_.push_back(analyzer_->transitive_comparisons.EnterConstraint(constraint_)); + recovery_functions_.push_back(analyzer_->z3_prover.EnterConstraint(constraint_, is_assume_)); } void ConstraintContext::ExitWithScope() { @@ -230,9 +235,30 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { } } + if (z3_prover.CanProve(simplified)) { + return true; + } return false; } +std::function Analyzer::EnterConstraint(const PrimExpr& constraint, bool is_assume) { + std::vector> recovery_functions; + recovery_functions.push_back(this->const_int_bound.EnterConstraint(constraint)); + recovery_functions.push_back(this->modular_set.EnterConstraint(constraint)); + recovery_functions.push_back(this->rewrite_simplify.EnterConstraint(constraint, is_assume)); + recovery_functions.push_back(this->int_set.EnterConstraint(constraint)); + recovery_functions.push_back(this->transitive_comparisons.EnterConstraint(constraint)); + recovery_functions.push_back(this->z3_prover.EnterConstraint(constraint, is_assume)); + return [recovery_functions]() { + for (auto it = recovery_functions.rbegin(); it != recovery_functions.rend(); ++it) { + auto& func = *it; + if (func) { + func(); + } + } + }; +} + PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { PrimExpr res = expr; @@ -345,6 +371,24 @@ TVM_FFI_STATIC_INIT_BLOCK() { self->rewrite_simplify.SetEnabledExtensions( static_cast(flags)); }); + } else if (name == "get_smtlib2") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + auto expr = args[0].cast>(); + *ret = self->z3_prover.GetSMTLIB2(expr); + }); + } else if (name == "get_z3_stats") { + return ffi::Function( + [self](ffi::PackedArgs args, ffi::Any* ret) { *ret = self->z3_prover.GetStats(); }); + } else if (name == "set_z3_timeout_ms") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + unsigned timeout_ms = args[0].cast(); + self->z3_prover.SetTimeoutMs(timeout_ms); + }); + } else if (name == "set_z3_rlimit") { + return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) { + unsigned rlimit = args[0].cast(); + self->z3_prover.SetRLimit(rlimit); + }); } return ffi::Function(); }; diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index ac2939f53063..13718998ca94 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -526,13 +526,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { return ret; } -std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint) { +std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint, + bool is_assume) { size_t old_literal_size = literal_constraints_.size(); // we will compare the already simplified result with the constraint, // so simplify the constraint as well PrimExpr new_constraint = operator()(constraint); for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint, false)) { - if (SideEffect(subconstraint) <= CallEffectKind::kPure) { + if (is_assume || SideEffect(subconstraint) <= CallEffectKind::kPure) { literal_constraints_.push_back(subconstraint); PrimExpr negation; if (subconstraint.dtype().is_bool()) { @@ -2440,8 +2441,9 @@ void RewriteSimplifier::Update(const Var& var, const PrimExpr& info, bool allow_ impl_->Update(var, info, allow_override); } -std::function RewriteSimplifier::EnterConstraint(const PrimExpr& constraint) { - return impl_->EnterConstraint(constraint); +std::function RewriteSimplifier::EnterConstraint(const PrimExpr& constraint, + bool is_assume) { + return impl_->EnterConstraint(constraint, is_assume); } void RewriteSimplifier::SetEnabledExtensions(Extension flags) { diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 5f2af7b81705..54168a4e4627 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -117,7 +117,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const CastNode* op) override; PrimExpr VisitExpr_(const LetNode* op) override; - std::function EnterConstraint(const PrimExpr& constraint); + std::function EnterConstraint(const PrimExpr& constraint, bool is_assume = false); /*! \brief Enable an optional extension or extensions * diff --git a/src/target/z3/z3_prover_off.cc b/src/target/z3/z3_prover_off.cc new file mode 100644 index 000000000000..6c30323fa952 --- /dev/null +++ b/src/target/z3/z3_prover_off.cc @@ -0,0 +1,38 @@ +#include +#include +#include + +#include "tvm/arith/analyzer.h" +#include "tvm/ffi/string.h" +#include "tvm/ir/expr.h" + +namespace tvm::arith { + +using namespace tirx; +using namespace ffi; + +class Z3Prover::Impl {}; + +TVM_DLL bool Z3Prover::CanProve(const PrimExpr& expr) { return false; } +TVM_DLL void Z3Prover::Bind(const Var& var, const Range& new_range, bool allow_override) {} +TVM_DLL void Z3Prover::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {} +std::function Z3Prover::EnterConstraint(const PrimExpr& constraint, bool is_assume) { + return []() {}; +} +ffi::String Z3Prover::GetSMTLIB2(const ffi::Optional expr) { + return "; Z3 Prover is disabled."; +} +void Z3Prover::SetTimeoutMs(unsigned timeout_ms) {} +void Z3Prover::SetRLimit(unsigned rlimit) {} +ffi::String Z3Prover::GetModel(const PrimExpr& expr) { return "; Z3 Prover is disabled."; } +TVM_DLL int64_t Z3Prover::CountSatisfyingValues(const Var& var, int64_t max_count, + int64_t min_consecutive) { + return -1; // Z3 disabled, return error +} + +void Z3Prover::CopyFrom(const Z3Prover& other) {} +ffi::String Z3Prover::GetStats() { return "; Z3 Prover is disabled."; } +Z3Prover::Z3Prover(Analyzer*) : impl_(nullptr) {} +TVM_DLL Z3Prover::~Z3Prover() {} + +} // namespace tvm::arith diff --git a/src/target/z3/z3_prover_on.cc b/src/target/z3/z3_prover_on.cc new file mode 100644 index 000000000000..4942bde94269 --- /dev/null +++ b/src/target/z3/z3_prover_on.cc @@ -0,0 +1,788 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tvm/arith/analyzer.h" +#include "tvm/ffi/cast.h" +#include "tvm/ffi/object.h" +#include "tvm/ffi/string.h" +#include "tvm/ir/expr.h" +#include "tvm/runtime/data_type.h" +#include "z3++.h" + +namespace tvm::arith { + +using namespace tirx; +using namespace ffi; + +namespace { + +struct Namespace { + std::unordered_set used_names; + /// @brief Get a new name that is not used before + /// This function is used to generate z3 variable names + /// + /// Z3 may deduplicate variables with the same name, which + /// causes issues when different TVM variables are mapped to + /// the same z3 variable. + /// + /// This function generates unique names by appending + /// suffixes to the original expression string representation. + /// + /// such as : "x", "x$1", "x$2", ... + std::string GetNewName(const PrimExpr& expr) { + std::stringstream ss; + ss << expr; + auto name = ss.str(); + if (used_names.count(name) == 0) { + used_names.insert(name); + return name; + } + int idx = 1; + std::string check_name = name + "$" + std::to_string(idx); + while (used_names.count(check_name)) { + idx++; + check_name = name + "$" + std::to_string(idx); + } + used_names.insert(check_name); + return check_name; + } +}; + +} // namespace + +class Z3Prover::Impl : ExprFunctor { + public: + using Base = ExprFunctor; + using Self = Z3Prover::Impl; + + Analyzer* analyzer; + /// @brief Z3 context, a shared ptr, because tilelang want to copy the Analyzer + // We use a thread_local static Z3 context so all analyzers within the same thread + // can share a common context, because Z3 initialization is slow on some CPUs + // (e.g., AMD EPYC 7502 32-Core). Using thread_local ensures thread safety. + inline static thread_local std::shared_ptr ctx{new z3::context()}; + + /// @brief Z3 solver instance + z3::solver solver{*ctx}; + + /// @brief Memorize pure expressions + std::unordered_map memo_; + + bool is_assume = false; + + /// @brief Namespace for variable naming + Namespace ns; + + /// @brief Timeout in milliseconds + unsigned timeout_ms{UINT_MAX}; + + /// @brief Max steps + unsigned rlimit{UINT_MAX}; + + /// @brief Create a z3 solver with custom options + static z3::solver CreateSolver(z3::context& ctx) { + z3::solver solver(ctx); + // here we disable model generation to speed up the solving process + solver.set("model", false); + // ensure determinstic behavior + solver.set("random_seed", (unsigned)42); + return solver; + } + + Impl(Analyzer* parent) : analyzer(parent) { + scope_stack_.push_back({}); + solver = CreateSolver(*ctx); + // default timeout 5ms + // Z3's implementation of timeout, when setting timeout T ms, it will stop at T - 1 ms + // SetTimeoutMs(5); + // use rlimit, not timeout to ensure determinstic behavior + SetRLimit(1e4); + } + + /// @brief Create a Free z3 expression from PrimExprNode + z3::expr Create(const PrimExprNode* op) { + auto ref = ffi::GetRef(op); + auto dtype = op->dtype; + std::string name = ns.GetNewName(ref); + /// TVM max_val can't handle uint64 max correctly, so we special case it here + if (dtype.is_bool()) { + return ctx->bool_const(name.c_str()); + } else { + z3::expr e = ctx->int_const(name.c_str()); + if (dtype.is_uint() && dtype.bits() == 64) { + solver.add(ctx->int_val(0) <= e && e <= ctx->int_val((uint64_t)UINT64_MAX)); + } else { + auto min_val = Downcast(min_value(dtype))->value; + auto max_val = Downcast(max_value(dtype))->value; + solver.add(ctx->int_val(min_val) <= e && e <= ctx->int_val(max_val)); + } + return e; + } + } + + struct Scope { + enum Kind { + BindValue, + BindRange, + Constraint, + } kind; + Var var; + PrimExpr value; + PrimExpr min; + PrimExpr extent; + PrimExpr constraint; + }; + + /// @brief scope_stack memorizes existing constraint and bindings + /// to generate SMTLIB2 representation with comments + std::vector> scope_stack_; + + /// @brief Enter a constraint scope + std::function EnterConstraint(const PrimExpr& constraint, bool is_assume = false) { + scope_stack_.push_back({}); + scope_stack_.back().push_back( + Scope{Scope::Constraint, Var(), PrimExpr(), PrimExpr(), PrimExpr(), constraint}); + solver.push(); + this->is_assume = is_assume; + solver.add(VisitBool(constraint)); + this->is_assume = false; + auto side_effect_exprs = std::move(side_effect_exprs_); + side_effect_exprs_.clear(); + if (is_assume) { + return [this, side_effect_exprs]() { + solver.pop(); + for (const auto& expr : side_effect_exprs) { + memo_.erase(expr); + } + scope_stack_.pop_back(); + }; + } else { + for (const auto& expr : side_effect_exprs) { + memo_.erase(expr); + } + return [this]() { + solver.pop(); + scope_stack_.pop_back(); + }; + } + } + + /// @brief Check trivil bad cases, return true if the expr is a bad case + /// Z3 prover may take a long time to initialize (at least 200us), + /// This optimization can speedup 30% of the test cases in our unit tests + bool CheckTrivilBadCases(const PrimExpr& expr) { + if (IsFreeNode(expr)) { + return true; + } + auto checkTrivilCmp = [this](const PrimExpr& lhs, const PrimExpr& rhs) { + if (IsFreeNode(lhs) && rhs->IsInstance()) { + return true; + } + if (IsFreeNode(rhs) && lhs->IsInstance()) { + return true; + } + if (IsFreeNode(lhs) && IsFreeNode(rhs)) { + return true; + } + // cast('xxx', free_var) == constant + if (auto cast = lhs.as()) { + if (IsFreeNode(cast->value) && rhs->IsInstance()) { + return true; + } + } + // constant == cast('xxx', free_var) + if (auto cast = rhs.as()) { + if (IsFreeNode(cast->value) && lhs->IsInstance()) { + return true; + } + } + return false; + }; + if (auto eq = expr.as()) { + auto lhs = eq->a; + auto rhs = eq->b; + return checkTrivilCmp(lhs, rhs); + } else if (auto ne = expr.as()) { + auto lhs = ne->a; + auto rhs = ne->b; + return checkTrivilCmp(lhs, rhs); + } + return false; + } + + /// @brief Check if the expression can be proved + bool CanProve(const PrimExpr& expr) { + if (CheckTrivilBadCases(expr)) return false; + if (!IsValidDType(expr->dtype)) return false; + z3::expr_vector constr(*ctx); + constr.push_back(!ConvertBool(expr)); + auto result = solver.check(constr); + constr.pop_back(); + return result == z3::unsat; + } + + /// @brief Binded + /// @brief Bind a variable to a value or a range + void Bind(const Var& var, const PrimExpr& value, bool allow_override = false) { + if (!IsValidDType(var->dtype)) return; + scope_stack_.back().push_back(Scope{Scope::BindValue, var, value}); + // we add the binding whenever the value is pure, + // because non-pure parts are handling by creating free variables in VisitExpr + memo_.emplace(var, ConvertInt(value)); + } + + /// @brief Bind a variable to a range + void Bind(const Var& var, const Range& range, bool allow_override = false) { + if (!IsValidDType(var->dtype)) return; + scope_stack_.back().push_back( + Scope{Scope::BindRange, var, PrimExpr(), range->min, range->extent}); + // 1. Create a placeholder for the var, and save it in the memo + // if the var is overrided later, we can just update the memo, and the old placeholder will + // be ignored + auto var_expr = Create(var.as()); + memo_.emplace(var, var_expr); + + // 2. Add constraint on the placeholder + // when min_expr >= max_expr, the range is empty, which is under undefined behavior + // instead of adding an unsat constraint, we just skip the range constraint to leave it a + // free var + if (tirx::is_const_int(range->min) && tirx::is_const_int(range->min + range->extent)) { + int64_t min_value = *tirx::as_const_int(range->min); + int64_t max_value = *tirx::as_const_int(range->min + range->extent); + if (min_value < max_value) { + solver.add(ctx->int_val(min_value) <= var_expr); + solver.add(var_expr < ctx->int_val(max_value)); + } + } else { + solver.add(ConvertBool(range->extent <= 0 || + (range->min <= var && var < range->min + range->extent))); + } + } + + void CopyFrom(const Self& other_) { + // 1. create a new solver + // because this->solver depends on this->ctx + // we need to deconstruct the old solver, and create a new one depending on other_.ctx + solver = CreateSolver(*other_.ctx); + // 2. copy the context + // the context is a shared_ptr, we can just copy the pointer + ctx = other_.ctx; + // 3. copy other objects + ns = other_.ns; + for (auto& item : other_.memo_) { + memo_.emplace(item.first, item.second); + } + for (auto a : other_.solver.assertions()) { + solver.add(a); + } + // 4. copy timeout options + // but other solver options are not copied + SetTimeoutMs(other_.timeout_ms); + SetRLimit(other_.rlimit); + // 5. copy the scope stack, which containing comments for SMTLIB2 generation + scope_stack_ = other_.scope_stack_; + } + + /// @brief Set timeout in milliseconds + void SetTimeoutMs(unsigned timeout_ms) { + this->timeout_ms = timeout_ms; + solver.set("timeout", timeout_ms); + } + + /// @brief Set max steps + void SetRLimit(unsigned rlimit) { + this->rlimit = rlimit; + solver.set("rlimit", rlimit); + } + + /// @brief Get the SMTLIB2 representation of the current solver state + ffi::String GetSMTLIB2() { + std::stringstream ss; + ss << "(set-option :timeout " << timeout_ms << ")\n"; + AddScopeDebugMsg(ss); + ss << solver.to_smt2(); + return ss.str(); + } + + void AddScopeDebugMsg(std::ostream& ss) { + for (const auto& scope : scope_stack_) { + ss << "; Entering Scope\n"; + for (const auto& s : scope) { + switch (s.kind) { + case Scope::Constraint: + ss << "; constraint: " << s.constraint << "\n"; + break; + case Scope::BindValue: + ss << "; bind value: " << s.var << " = " << s.value << "\n"; + break; + case Scope::BindRange: + ss << "; bind range: " << s.var << " in [" << s.min << ", " << s.min + s.extent + << ")\n"; + break; + } + } + } + } + + /// @brief Get the SMTLIB2 representation of the current solver state with additional expr trying + /// to prove + ffi::String GetSMTLIB2(const PrimExpr& expr) { + std::stringstream ss; + ss << "(set-option :timeout " << timeout_ms << ")\n"; + AddScopeDebugMsg(ss); + ss << "; Trying to prove: " << expr << "\n"; + solver.push(); + solver.add(!ConvertBool(expr)); + ss << solver.to_smt2(); + solver.pop(); + return ss.str(); + } + + /// @brief Get the statistics of the solver + ffi::String GetStats() { + std::stringstream ss; + ss << solver.statistics(); + return ss.str(); + } + + ffi::String GetModel(const PrimExpr& expr) { + solver.set("model", true); + solver.push(); + solver.add(!ConvertBool(expr)); + auto result = solver.check(); + ffi::String model_str; + if (result == z3::sat) { + z3::model m = solver.get_model(); + std::map model_map; + for (unsigned i = 0; i < m.size(); i++) { + z3::func_decl d = m[i]; + model_map.emplace(d.name().str(), m.get_const_interp(d)); + } + std::stringstream ss; + for (const auto& [k, v] : model_map) { + ss << " " << k << " = " << v << "\n"; + } + model_str = ss.str(); + } + solver.pop(); + solver.set("model", false); + return model_str; + } + + /*! + * \brief Count the number of distinct integer values satisfying current constraints. + * + * Uses Z3's model enumeration (AllSAT pattern) to count solutions: + * 1. Find a satisfying assignment + * 2. Add a blocking clause to exclude it + * 3. Repeat until UNSAT + * + * \param var The variable to count values for + * \param max_count Safety limit on enumeration + * \param min_consecutive Minimum consecutive count requirement (0 to disable) + * \return Number of satisfying values, -1 on error, -2 if min_consecutive constraint not met + */ + int64_t CountSatisfyingValues(const Var& var, int64_t max_count, int64_t min_consecutive = 1) { + if (!IsValidDType(var->dtype)) { + return -1; + } + + solver.set("model", true); + solver.push(); + + // Convert the TVM variable to Z3 expression + z3::expr z3_var = VisitInt(var); + + int64_t count = 0; + std::vector found_values; + + while (count < max_count) { + auto result = solver.check(); + if (result != z3::sat) { + break; // No more solutions + } + + z3::model m = solver.get_model(); + z3::expr val_expr = m.eval(z3_var, true); + + // Extract the integer value from Z3 expression + int64_t val; + if (val_expr.is_numeral()) { + val = val_expr.get_numeral_int64(); + } else { + // If we can't get a concrete value, stop enumeration + break; + } + + found_values.push_back(val); + count++; + + // Add blocking clause: var != val (exclude this solution) + solver.add(z3_var != ctx->int_val(val)); + } + + solver.pop(); + solver.set("model", false); + + // Clear any side effects from visiting the variable + for (const auto& expr : side_effect_exprs_) { + memo_.erase(expr); + } + side_effect_exprs_.clear(); + + // Check minimum consecutive constraint if enabled + if (min_consecutive > 0 && count > 0) { + // Sort the values to check consecutive groups + std::sort(found_values.begin(), found_values.end()); + + // Check that all values form groups of at least min_consecutive consecutive numbers + int64_t consecutive_count = 1; + for (size_t i = 1; i < found_values.size(); i++) { + if (found_values[i] == found_values[i - 1] + 1) { + // Consecutive value + consecutive_count++; + } else { + // Gap found, check if the previous group meets the minimum + if (consecutive_count < min_consecutive) { + return -2; // Previous group too small + } + consecutive_count = 1; // Start new group + } + } + // Check the last group + if (consecutive_count < min_consecutive) { + return -2; // Last group too small + } + } + + return count; + } + + private: + using Z3BinOp = z3::expr (*)(const z3::expr&, const z3::expr&); + + std::vector side_effect_exprs_; + + z3::expr ConvertBool(const PrimExpr& e, bool is_assume = false) { + this->is_assume = is_assume; + auto res = VisitBool(e); + for (auto& expr : side_effect_exprs_) { + memo_.erase(expr); + } + side_effect_exprs_.clear(); + this->is_assume = false; + return res; + } + + z3::expr ConvertInt(const PrimExpr& e, bool is_assume = false) { + this->is_assume = is_assume; + auto res = VisitInt(e); + for (auto& expr : side_effect_exprs_) { + memo_.erase(expr); + } + side_effect_exprs_.clear(); + this->is_assume = false; + return res; + } + + /// @brief Visit expression with memoization + z3::expr VisitExpr(const PrimExpr& e) override { + if (memo_.count(e)) { + return memo_.at(e); + } + auto res = Base::VisitExpr(e); + auto side_effect = SideEffect(e); + if (side_effect <= CallEffectKind::kPure) { + memo_.emplace(e, res); + } else if (side_effect <= CallEffectKind::kReadState) { + memo_.emplace(e, res); + side_effect_exprs_.emplace_back(e); + } else { + if (is_assume) { + memo_.emplace(e, res); + } + side_effect_exprs_.emplace_back(e); + } + return res; + } + + /// @brief Check if the expression is a free node having no constraints + bool IsFreeNode(const PrimExpr& e) { + if (memo_.count(e)) { + return false; + } + return e->IsInstance() || e->IsInstance() || + e->IsInstance() || e->IsInstance() || + (e->IsInstance() && !IsValidDType(Downcast(e)->value->dtype)); + } + + /// @brief Check if the dtype is valid for z3 integer operations + static bool IsValidDType(const DataType& dtype) { + return (dtype.is_int() || dtype.is_uint() || dtype.is_bool()) && dtype.lanes() == 1; + } + + /// @brief Visit the expression and convert it into z3 integer expression + z3::expr VisitInt(const PrimExpr& expr) { + auto e = VisitExpr(expr); + if (e.is_bool()) { + return z3::ite(e, ctx->int_val(1), ctx->int_val(0)); + } else { + return e; + } + } + + /// @brief Visit the expression and convert it into z3 boolean expression + z3::expr VisitBool(const PrimExpr& e) { + auto expr = VisitExpr(e); + if (expr.is_bool()) { + return expr; + } else { + return expr != ctx->int_val(0); + } + } + + /// @brief Helper function to visit binary arithmetic operations + z3::expr VisitArith(Z3BinOp signed_op, const PrimExprNode* op, const PrimExpr& a, + const PrimExpr& b) { + if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) { + return signed_op(VisitInt(a), VisitInt(b)); + } else { + return Create(op); + } + } + + z3::expr VisitExpr_(const LetNode* op) override { + if (IsValidDType(op->var->dtype)) { + memo_.emplace(op->var, VisitInt(op->value)); + } + return VisitExpr(op->body); + } + z3::expr VisitExpr_(const CastNode* op) override { + // if the inner dtype is valid, we just visit it + if (IsValidDType(op->value->dtype) && IsValidDType(op->dtype)) { + return VisitInt(op->value); + } else { + // otherwise, we create a new free z3 variable + return Create(op); + } + } + z3::expr VisitExpr_(const VarNode* op) override { return Create(op); } + z3::expr VisitExpr_(const BufferLoadNode* op) override { return Create(op); } + z3::expr VisitExpr_(const ProducerLoadNode* op) override { return Create(op); } + z3::expr VisitExpr_(const ReduceNode* op) override { return Create(op); } + z3::expr VisitExpr_(const MinNode* op) override { + auto a = VisitInt(op->a); + auto b = VisitInt(op->b); + return z3::ite(a < b, a, b); + } + z3::expr VisitExpr_(const MaxNode* op) override { + auto a = VisitInt(op->a); + auto b = VisitInt(op->b); + return z3::ite(a > b, a, b); + } + static z3::expr floordiv(const z3::expr& a, const z3::expr& b) { + return z3::ite(b > 0, a / b, -((-a) / b)); + } + static z3::expr floormod(const z3::expr& a, const z3::expr& b) { + return z3::ite(b > 0, a % b, -((-a) % b)); + } + z3::expr VisitExpr_(const AddNode* op) override { + return VisitArith(z3::operator+, op, op->a, op->b); + } + z3::expr VisitExpr_(const SubNode* op) override { + return VisitArith(z3::operator-, op, op->a, op->b); + } + z3::expr VisitExpr_(const MulNode* op) override { + return VisitArith(z3::operator*, op, op->a, op->b); + } + z3::expr VisitExpr_(const DivNode* op) override { + return VisitArith(z3::operator/, op, op->a, op->b); + } + z3::expr VisitExpr_(const ModNode* op) override { + return VisitArith(z3::operator%, op, op->a, op->b); + } + z3::expr VisitExpr_(const FloorDivNode* op) override { + return VisitArith(floordiv, op, op->a, op->b); + } + z3::expr VisitExpr_(const FloorModNode* op) override { + return VisitArith(floormod, op, op->a, op->b); + } + z3::expr VisitExpr_(const EQNode* op) override { + return VisitArith(z3::operator==, op, op->a, op->b); + } + z3::expr VisitExpr_(const NENode* op) override { + return VisitArith(z3::operator!=, op, op->a, op->b); + } + z3::expr VisitExpr_(const LTNode* op) override { + return VisitArith(z3::operator<, op, op->a, op->b); + } + z3::expr VisitExpr_(const LENode* op) override { + return VisitArith(z3::operator<=, op, op->a, op->b); + } + z3::expr VisitExpr_(const GTNode* op) override { + return VisitArith(z3::operator>, op, op->a, op->b); + } + z3::expr VisitExpr_(const GENode* op) override { + return VisitArith(z3::operator>=, op, op->a, op->b); + } + z3::expr VisitExpr_(const AndNode* op) override { return VisitBool(op->a) && VisitBool(op->b); } + z3::expr VisitExpr_(const OrNode* op) override { return VisitBool(op->a) || VisitBool(op->b); } + z3::expr VisitExpr_(const NotNode* op) override { return !VisitBool(op->a); } + z3::expr VisitExpr_(const SelectNode* op) override { + return z3::ite(VisitBool(op->condition), VisitInt(op->true_value), VisitInt(op->false_value)); + } + z3::expr VisitExpr_(const IntImmNode* op) override { return ctx->int_val(op->value); } + + // Bitwise operations + z3::expr VisitExpr_(const CallNode* op) override { + // Check if this is a bitwise operation + if (op->op.same_as(tirx::builtin::bitwise_and())) { + return VisitBitwiseOp(z3::operator&, op); + } else if (op->op.same_as(tirx::builtin::bitwise_or())) { + return VisitBitwiseOp(z3::operator|, op); + } else if (op->op.same_as(tirx::builtin::bitwise_xor())) { + return VisitBitwiseOp(z3::operator^, op); + } else if (op->op.same_as(tirx::builtin::bitwise_not())) { + return VisitBitwiseNotOp(op); + } else if (op->op.same_as(tirx::builtin::shift_left())) { + return VisitShiftOp(z3::shl, op); + } else if (op->op.same_as(tirx::builtin::shift_right())) { + return VisitShiftOp(z3::ashr, op); + } else { + // For other call nodes, create a free variable + return Create(op); + } + } + + /// @brief Helper function to visit binary bitwise operations + z3::expr VisitBitwiseOp(z3::expr (*op_func)(const z3::expr&, const z3::expr&), + const CallNode* op) { + if (op->args.size() != 2) { + LOG(FATAL) << "Binary bitwise operation expects 2 arguments, got " << op->args.size(); + TVM_FFI_UNREACHABLE(); + } + + const PrimExpr& a = op->args[0]; + const PrimExpr& b = op->args[1]; + unsigned bit_width = std::max(op->args[0].dtype().bits(), op->args[1].dtype().bits()); + + if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) { + return z3::bv2int( + op_func(z3::int2bv(bit_width, VisitInt(a)), z3::int2bv(bit_width, VisitInt(b))), true); + } else { + return Create(op); + } + } + + /// @brief Helper function to visit unary bitwise not operation + z3::expr VisitBitwiseNotOp(const CallNode* op) { + if (op->args.size() != 1) { + LOG(FATAL) << "Bitwise not operation expects 1 argument, got " << op->args.size(); + TVM_FFI_UNREACHABLE(); + } + + const PrimExpr& a = op->args[0]; + + if (IsValidDType(a->dtype)) { + // Cast integer to bit-vector, apply bitwise not, then cast back. + unsigned bit_width = a.dtype().bits(); + z3::expr a_int = VisitInt(a); + z3::expr a_bv = z3::int2bv(bit_width, a_int); + return z3::bv2int(~a_bv, true); + } else { + return Create(op); + } + } + + /// @brief Helper function to visit shift operations + z3::expr VisitShiftOp(z3::expr (*op_func)(const z3::expr&, const z3::expr&), const CallNode* op) { + if (op->args.size() != 2) { + LOG(FATAL) << "Shift operation expects 2 arguments, got " << op->args.size(); + TVM_FFI_UNREACHABLE(); + } + + const PrimExpr& a = op->args[0]; + const PrimExpr& b = op->args[1]; + + // Shift operations require integer types for both operands + if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) { + // For shift operations, we need to ensure the shift amount is non-negative + // and within reasonable bounds + z3::expr a_expr = VisitInt(a); + z3::expr b_expr = VisitInt(b); + + // Add constraint that shift amount should be non-negative + // This is a common assumption in many programming languages + solver.add(b_expr >= 0); + + // Also limit shift amount to avoid unrealistic large shifts + // We'll limit to 64 bits (reasonable for most use cases) + solver.add(b_expr < 64); + + unsigned bit_width = std::max(a.dtype().bits(), b.dtype().bits()); + z3::expr a_bv = z3::int2bv(bit_width, a_expr); + z3::expr b_bv = z3::int2bv(bit_width, b_expr); + + // Perform the shift in bit-vector domain, then cast back to int. + z3::expr result_bv = op_func(a_bv, b_bv); + return z3::bv2int(result_bv, true); + } else { + return Create(op); + } + } + + z3::expr VisitExprDefault_(const Object* op) override { + LOG(FATAL) << "Z3Prover only support integers, but got " << op->GetTypeKey() << "."; + TVM_FFI_UNREACHABLE(); + } +}; + +TVM_DLL bool Z3Prover::CanProve(const PrimExpr& expr) { return impl_->CanProve(expr); } +TVM_DLL void Z3Prover::Bind(const Var& var, const Range& new_range, bool allow_override) { + return impl_->Bind(var, new_range, allow_override); +} +TVM_DLL void Z3Prover::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { + return impl_->Bind(var, expr, allow_override); +} +std::function Z3Prover::EnterConstraint(const PrimExpr& constraint, bool is_assume) { + return impl_->EnterConstraint(constraint, is_assume); +} +ffi::String Z3Prover::GetSMTLIB2(const ffi::Optional expr) { + if (expr.has_value()) { + return impl_->GetSMTLIB2(expr.value()); + } else { + return impl_->GetSMTLIB2(); + } +} +void Z3Prover::SetTimeoutMs(unsigned timeout_ms) { impl_->SetTimeoutMs(timeout_ms); } +void Z3Prover::SetRLimit(unsigned max_step) { impl_->SetRLimit(max_step); } +void Z3Prover::CopyFrom(const Z3Prover& other) { impl_->CopyFrom(*other.impl_); } +ffi::String Z3Prover::GetStats() { return impl_->GetStats(); } +ffi::String Z3Prover::GetModel(const PrimExpr& expr) { return impl_->GetModel(expr); } +TVM_DLL int64_t Z3Prover::CountSatisfyingValues(const Var& var, int64_t max_count, + int64_t min_consecutive) { + return impl_->CountSatisfyingValues(var, max_count, min_consecutive); +} +Z3Prover::Z3Prover(Analyzer* parent) : impl_(new Impl{parent}) {} +TVM_DLL Z3Prover::~Z3Prover() { delete impl_; } + +} // namespace tvm::arith diff --git a/tests/python/arith/test_arith_z3.py b/tests/python/arith/test_arith_z3.py new file mode 100644 index 000000000000..c638341f4cd1 --- /dev/null +++ b/tests/python/arith/test_arith_z3.py @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +import tvm +import tvm.testing +from tvm import tirx + + +def _require_z3(analyzer): + if "Z3 Prover is disabled" in analyzer.get_smtlib2(): + pytest.skip("Z3 prover is disabled in this build") + + +def test_z3_disabled_api_is_available(): + analyzer = tvm.arith.Analyzer() + assert isinstance(analyzer.get_smtlib2(), str) + assert isinstance(analyzer.get_z3_stats(), str) + + +def test_z3_proves_floor_division_identity(): + analyzer = tvm.arith.Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + c = tirx.Var("c", "int32") + + with analyzer.constraint_scope(tirx.all(a > 0, b > 0, c > 0)): + expr = ((b - a) // c) * c + a <= b + assert analyzer.can_prove(expr) + + +def test_z3_bind_range(): + analyzer = tvm.arith.Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + c = tirx.Var("c", "int32") + + analyzer.bind(a, tvm.ir.Range(1, 100000)) + analyzer.bind(b, tvm.ir.Range(1, 100000)) + analyzer.bind(c, tvm.ir.Range(1, 100000)) + + expr = ((b - a) // c) * c + a <= b + assert analyzer.can_prove(expr) + + +def test_z3_smtlib2_roundtrip(): + z3 = pytest.importorskip("z3") + analyzer = tvm.arith.Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + c = tirx.Var("c", "int32") + expr = ((b - a) // c) * c + a <= b + + solver = z3.Solver() + with analyzer.constraint_scope(tirx.all(a > 0, b > 0, c > 0)): + solver.from_string(analyzer.get_smtlib2(expr)) + assert solver.check() == z3.unsat + + +def test_z3_bitwise(): + analyzer = tvm.arith.Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + analyzer.bind(x, tvm.ir.Range(0, 256)) + + assert analyzer.can_prove(tirx.bitwise_and(x, tirx.IntImm("int32", 7)) < 8) + + +if __name__ == "__main__": + tvm.testing.main()