Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ tvm_option(COMPILER_RT_PATH "Path to COMPILER-RT" "3rdparty/compiler-rt")
tvm_option(USE_BYODT_POSIT "Build with BYODT software emulated posit custom datatype" OFF)
tvm_option(USE_BLAS "The blas library to be linked" none)
tvm_option(USE_AMX "Enable Intel AMX" OFF)
tvm_option(USE_Z3 "Build with Z3 SMT solver support" OFF)
tvm_option(USE_MKL "MKL root path when use MKL blas" OFF)
tvm_option(USE_DNNL "Enable DNNL codegen" OFF)
tvm_option(USE_CUDNN "Build with cuDNN" OFF)
Expand Down Expand Up @@ -447,6 +448,7 @@ include(cmake/modules/contrib/CUTLASS.cmake)
include(cmake/modules/contrib/Random.cmake)
include(cmake/modules/contrib/Posit.cmake)
include(cmake/modules/contrib/Sort.cmake)
include(cmake/modules/contrib/Z3.cmake)
include(cmake/modules/contrib/CoreML.cmake)
include(cmake/modules/contrib/TensorRT.cmake)
include(cmake/modules/contrib/NNAPI.cmake)
Expand Down
76 changes: 76 additions & 0 deletions cmake/modules/contrib/Z3.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

if(NOT USE_Z3)
list(APPEND COMPILER_SRCS src/target/z3/z3_prover_off.cc)
return()
endif()

find_package(Z3 QUIET)
set(Z3_PYTHON_RESULT 1)

if(NOT Z3_FOUND)
find_package(Python3 COMPONENTS Interpreter QUIET)
if(Python3_EXECUTABLE)
execute_process(
COMMAND "${Python3_EXECUTABLE}" -c "import z3; print(z3.__path__[0])"
OUTPUT_VARIABLE Z3_PYTHON_PACKAGE_DIR
OUTPUT_STRIP_TRAILING_WHITESPACE
RESULT_VARIABLE Z3_PYTHON_RESULT
)
endif()

if(Z3_PYTHON_RESULT EQUAL 0 AND NOT Z3_PYTHON_PACKAGE_DIR STREQUAL "")
find_path(Z3_INCLUDE_DIR NO_DEFAULT_PATH NAMES z3++.h PATHS "${Z3_PYTHON_PACKAGE_DIR}/include")
find_library(
Z3_LIBRARY
NO_DEFAULT_PATH
NAMES z3 libz3
PATHS "${Z3_PYTHON_PACKAGE_DIR}/bin" "${Z3_PYTHON_PACKAGE_DIR}/lib"
"${Z3_PYTHON_PACKAGE_DIR}/lib64"
)
endif()
endif()

if(TARGET z3::libz3 OR TARGET Z3::libz3)
if(TARGET z3::libz3)
set(Z3_TARGET z3::libz3)
else()
set(Z3_TARGET Z3::libz3)
endif()
get_target_property(Z3_TARGET_INCLUDE_DIRS ${Z3_TARGET} INTERFACE_INCLUDE_DIRECTORIES)
if(Z3_TARGET_INCLUDE_DIRS)
include_directories(SYSTEM ${Z3_TARGET_INCLUDE_DIRS})
endif()
list(APPEND TVM_LINKER_LIBS ${Z3_TARGET})
elseif(Z3_FOUND OR (Z3_INCLUDE_DIR AND Z3_LIBRARY))
if(NOT Z3_INCLUDE_DIR AND Z3_CXX_INCLUDE_DIRS)
set(Z3_INCLUDE_DIR ${Z3_CXX_INCLUDE_DIRS})
endif()
if(NOT Z3_LIBRARY AND Z3_LIBRARIES)
set(Z3_LIBRARY ${Z3_LIBRARIES})
endif()
if(NOT Z3_INCLUDE_DIR OR NOT Z3_LIBRARY)
message(FATAL_ERROR "USE_Z3 is ON, but Z3 include directory or library was not found.")
endif()
include_directories(SYSTEM ${Z3_INCLUDE_DIR})
list(APPEND TVM_LINKER_LIBS ${Z3_LIBRARY})
else()
message(FATAL_ERROR "USE_Z3 is ON, but Z3 was not found. Install Z3 or PyPI z3-solver.")
endif()

list(APPEND COMPILER_SRCS src/target/z3/z3_prover_on.cc)
109 changes: 106 additions & 3 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <tvm/arith/int_set.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ffi/string.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/with_context.h>

Expand Down Expand Up @@ -295,7 +296,7 @@ class RewriteSimplifier {
*
* \return an exit function that must be called to cleanup the constraint can be nullptr.
*/
TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint);
TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint, bool is_assume = false);

/*! \brief Flags to enable more computationally-intensive simplifications
*
Expand Down Expand Up @@ -554,8 +555,8 @@ class ConstraintContext {
* \param analyzer The analyzer.
* \param constraint The constraint to be applied.
*/
ConstraintContext(Analyzer* analyzer, PrimExpr constraint)
: analyzer_(analyzer), constraint_(constraint) {}
ConstraintContext(Analyzer* analyzer, PrimExpr constraint, bool is_assume = false)
: analyzer_(analyzer), constraint_(constraint), is_assume_(is_assume) {}
// enter the scope.
void EnterWithScope();
// exit the scope.
Expand All @@ -566,6 +567,7 @@ class ConstraintContext {
PrimExpr constraint_;
/*! \brief functions to be called in recovery */
std::vector<std::function<void()>> recovery_functions_;
bool is_assume_;
};

/*!
Expand Down Expand Up @@ -622,6 +624,103 @@ class IntSetAnalyzer {
Impl* impl_;
};

class Z3Prover {
public:
/*!
* \brief Update binding of var to a new expression.
*
* \param var The variable of interest.
* \param new_range The range of allowed values for this var.
* \param allow_override whether we allow override of existing information.
*/
TVM_DLL void Bind(const Var& var, const Range& new_range, bool allow_override = false);

/*!
* \brief Update binding of var to a new expression.
*
* \param var The variable of interest.
* \param expr The bound expression.
* \param allow_override whether we allow override of existing information.
*/
TVM_DLL void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);

/*!
* \brief Whether can we prove expr is always true.
*
* \param expr The expression.
* \return Whether we can prove it.
*/
TVM_DLL bool CanProve(const PrimExpr& expr);

/*!
* \brief Update the internal state to enter constraint.
*
* \param constraint A constraint expression.
* \param is_assume Whether the constraint comes from an assumption.
* \return an exit function that must be called to cleanup the constraint can be nullptr.
*/
std::function<void()> EnterConstraint(const PrimExpr& constraint, bool is_assume = false);

/*!
* \brief Get the SMTLIB2 representation of the current context.
*
* \param expr The optional expression to check.
* \return The SMTLIB2 string.
*/
ffi::String GetSMTLIB2(const ffi::Optional<PrimExpr> expr);

/*!
* \brief Get statistics about Z3 prover.
*
* \return The statistics string.
*/
ffi::String GetStats();

/*!
* \brief Set timeout in milliseconds for Z3 prover.
*
* \param timeout_ms The timeout in milliseconds.
*/
void SetTimeoutMs(unsigned timeout_ms);

/*!
* \brief Set resource limitation for Z3 prover.
*
* \param rlimit the resource limitation.
*/
void SetRLimit(unsigned rlimit);

/*!
* \brief Get the Z3 model for the given expression if satisfiable.
*
* \param expr The expression to get the model for.
* \return The model as a string.
*/
ffi::String GetModel(const PrimExpr& expr);

/*!
* \brief Count the number of integer values that satisfy the current constraints.
*
* This method uses Z3's model enumeration to count how many distinct values of
* the given variable satisfy all current constraints.
*
* \param var The variable to count satisfying values for.
* \param max_count Maximum number of solutions to enumerate.
* \param min_consecutive Minimum consecutive count requirement.
* \return The number of distinct values that satisfy the constraints, or a negative error code.
*/
TVM_DLL int64_t CountSatisfyingValues(const Var& var, int64_t max_count = 2048,
int64_t min_consecutive = 1);

private:
friend class Analyzer;
explicit Z3Prover(Analyzer* parent);
TVM_DLL ~Z3Prover();
void CopyFrom(const Z3Prover& other);
class Impl;
Impl* impl_;
};

/*!
* \brief Analyzer that contains bunch of sub-analyzers.
*
Expand Down Expand Up @@ -651,6 +750,8 @@ class TVM_DLL Analyzer {
IntSetAnalyzer int_set;
/*! \brief sub-analyzer transitive comparisons */
TransitiveComparisonAnalyzer transitive_comparisons;
/*! \brief sub-analyzer using Z3 */
Z3Prover z3_prover;
/*! \brief constructor */
Analyzer();
/*!
Expand Down Expand Up @@ -785,6 +886,8 @@ class TVM_DLL Analyzer {
* \note Analyzer will call into sub-analyzers to get the result.
*/
PrimExpr Simplify(const PrimExpr& expr, int steps = 2);

std::function<void()> EnterConstraint(const PrimExpr& constraint, bool is_assume = false);
};

} // namespace arith
Expand Down
44 changes: 44 additions & 0 deletions python/tvm/arith/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,53 @@ def __init__(self):
self._enter_constraint_context = _mod("enter_constraint_context")
self._can_prove_equal = _mod("can_prove_equal")
self._can_prove = _mod("can_prove")
self._get_smtlib2 = _mod("get_smtlib2")
self._set_z3_timeout_ms = _mod("set_z3_timeout_ms")
self._set_z3_rlimit = _mod("set_z3_rlimit")
self._get_z3_stats = _mod("get_z3_stats")
self._get_enabled_extensions = _mod("get_enabled_extensions")
self._set_enabled_extensions = _mod("set_enabled_extensions")

def get_smtlib2(self, expr: tirx.PrimExpr = None) -> str:
"""Get the current Z3 problem in SMT-LIB2 format.

Parameters
----------
expr : Optional[PrimExpr]
The expression to prove. If provided, its negation is added to the problem.
"""
return self._get_smtlib2(expr)

def set_z3_timeout_ms(self, timeout_ms: int) -> None:
"""Set Z3 timeout in milliseconds.

Parameters
----------
timeout_ms : int
The timeout in milliseconds.
"""
self._set_z3_timeout_ms(timeout_ms)

def set_z3_rlimit(self, rlimit: int) -> None:
"""Set Z3 resource limit.

Parameters
----------
rlimit : int
The resource limit.
"""
self._set_z3_rlimit(rlimit)

def get_z3_stats(self) -> str:
"""Get Z3 solver statistics.

Returns
-------
stats : str
The Z3 statistics.
"""
return self._get_z3_stats()

def const_int_bound(self, expr: tirx.PrimExpr) -> ConstIntBound:
"""Find constant integer bound for expr.

Expand Down
48 changes: 46 additions & 2 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ Analyzer::Analyzer()
modular_set(this),
rewrite_simplify(this),
canonical_simplify(this),
int_set(this) {}
int_set(this),
z3_prover(this) {}

void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
PrimExpr new_expr = expr;
Expand All @@ -52,6 +53,7 @@ void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
this->canonical_simplify.Update(var, new_expr, allow_override);
this->int_set.Update(var, this->int_set(new_expr), allow_override);
this->transitive_comparisons.Bind(var, expr, allow_override);
this->z3_prover.Bind(var, expr, allow_override);
}

void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
Expand All @@ -62,6 +64,7 @@ void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
this->const_int_bound.Bind(var, range, allow_override);
this->int_set.Bind(var, range, allow_override);
this->transitive_comparisons.Bind(var, range, allow_override);
this->z3_prover.Bind(var, range, allow_override);
}
// skip modular_set
// skip rewrite simplify
Expand Down Expand Up @@ -128,9 +131,11 @@ void ConstraintContext::EnterWithScope() {
// entering the scope.
recovery_functions_.push_back(analyzer_->const_int_bound.EnterConstraint(constraint_));
recovery_functions_.push_back(analyzer_->modular_set.EnterConstraint(constraint_));
recovery_functions_.push_back(analyzer_->rewrite_simplify.EnterConstraint(constraint_));
recovery_functions_.push_back(
analyzer_->rewrite_simplify.EnterConstraint(constraint_, is_assume_));
recovery_functions_.push_back(analyzer_->int_set.EnterConstraint(constraint_));
recovery_functions_.push_back(analyzer_->transitive_comparisons.EnterConstraint(constraint_));
recovery_functions_.push_back(analyzer_->z3_prover.EnterConstraint(constraint_, is_assume_));
}

void ConstraintContext::ExitWithScope() {
Expand Down Expand Up @@ -230,9 +235,30 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) {
}
}

if (z3_prover.CanProve(simplified)) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Z3 fallback isn't exception-safe — Z3 can throw z3::exception and nothing catches it, so it escapes Analyzer::CanProve and breaks the calling pass. A fallback shouldn't turn a failed proof into a crash. it's good to wrap the body of Z3Prover::CanProve in try { ... } catch (const z3::exception&) { return false; }.

return true;
}
return false;
}

std::function<void()> Analyzer::EnterConstraint(const PrimExpr& constraint, bool is_assume) {
std::vector<std::function<void()>> recovery_functions;
recovery_functions.push_back(this->const_int_bound.EnterConstraint(constraint));
recovery_functions.push_back(this->modular_set.EnterConstraint(constraint));
recovery_functions.push_back(this->rewrite_simplify.EnterConstraint(constraint, is_assume));
recovery_functions.push_back(this->int_set.EnterConstraint(constraint));
recovery_functions.push_back(this->transitive_comparisons.EnterConstraint(constraint));
recovery_functions.push_back(this->z3_prover.EnterConstraint(constraint, is_assume));
return [recovery_functions]() {
for (auto it = recovery_functions.rbegin(); it != recovery_functions.rend(); ++it) {
auto& func = *it;
if (func) {
func();
}
}
};
}

PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) {
PrimExpr res = expr;

Expand Down Expand Up @@ -345,6 +371,24 @@ TVM_FFI_STATIC_INIT_BLOCK() {
self->rewrite_simplify.SetEnabledExtensions(
static_cast<RewriteSimplifier::Extension>(flags));
});
} else if (name == "get_smtlib2") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
auto expr = args[0].cast<ffi::Optional<PrimExpr>>();
*ret = self->z3_prover.GetSMTLIB2(expr);
});
} else if (name == "get_z3_stats") {
return ffi::Function(
[self](ffi::PackedArgs args, ffi::Any* ret) { *ret = self->z3_prover.GetStats(); });
} else if (name == "set_z3_timeout_ms") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
unsigned timeout_ms = args[0].cast<unsigned>();
self->z3_prover.SetTimeoutMs(timeout_ms);
});
} else if (name == "set_z3_rlimit") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
unsigned rlimit = args[0].cast<unsigned>();
self->z3_prover.SetRLimit(rlimit);
});
}
return ffi::Function();
};
Expand Down
Loading
Loading