Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions .tmp_scatter_cuda_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import numpy as np
import torch
from torch import nn
from torch.export import export
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program
from tvm.relax.backend.cuda import get_default_pipeline

class ScatterValue(nn.Module):
def forward(self, x, index):
return x.scatter(1, index, 0.5)

torch.manual_seed(0)
x = torch.randn(4, 8, dtype=torch.float32)
idx = torch.randint(0, 8, (4, 2), dtype=torch.int64)

mod = from_exported_program(export(ScatterValue(), args=(x, idx)))
tgt = tvm.target.Target('cuda')
with tgt:
mod = get_default_pipeline(tgt)(mod)

ex = relax.build(mod, tgt, relax_pipeline=None)
vm = relax.VirtualMachine(ex, tvm.cuda(0))
out = vm['main'](
tvm.runtime.tensor(x.numpy(), device=tvm.cuda(0)),
tvm.runtime.tensor(idx.numpy(), device=tvm.cuda(0)),
)
out_np = out.numpy() if hasattr(out, 'numpy') else out[0].numpy()
ref_np = ScatterValue()(x, idx).numpy()

print('shape_match', out_np.shape == ref_np.shape)
print('allclose', np.allclose(out_np, ref_np, rtol=1e-5, atol=1e-6))
print('max_abs_diff', float(np.max(np.abs(out_np - ref_np))))
7 changes: 7 additions & 0 deletions include/tvm/tirx/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,12 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner();
*/
TVM_DLL Pass BindTarget(Target target);

/*!
* \brief Convert ForKind::kParallel loops to blockIdx.x/threadIdx.x bindings on GPU targets.
* \return The pass.
*/
TVM_DLL Pass BindParallelLoopsToThreads();

/*!
* \brief Set a PrimFunc as the entry point if it is only function in IRModule.
* \return The pass.
Expand All @@ -354,6 +360,7 @@ TVM_DLL Pass Filter(ffi::TypedFunction<bool(PrimFunc)> fcond);
*
* \return The pass.
*/

} // namespace transform
} // namespace tirx
} // namespace tvm
Expand Down
21 changes: 21 additions & 0 deletions python/tvm/contrib/nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,27 @@ def _compile_cuda_nvcc(
else:
raise ValueError("options must be str or list of str")

# Optional workaround for NVCC host compiler version checks on Windows.
# Priority:
# 1) PassContext config: cuda.nvcc_allow_unsupported_compiler (bool)
# 2) Environment variable: TVM_CUDA_ALLOW_UNSUPPORTED_COMPILER in {"1","true","on","yes"}
# 3) Default: False
allow_unsupported_compiler = False
if "cuda.nvcc_allow_unsupported_compiler" in pass_context.config:
allow_unsupported_compiler = bool(
pass_context.config["cuda.nvcc_allow_unsupported_compiler"]
)
else:
env_val = os.environ.get("TVM_CUDA_ALLOW_UNSUPPORTED_COMPILER", "").strip().lower()
allow_unsupported_compiler = env_val in {"1", "true", "on", "yes"}

if (
platform.system() == "Windows"
and allow_unsupported_compiler
and "-allow-unsupported-compiler" not in cmd
):
cmd += ["-allow-unsupported-compiler"]

cmd += ["-o", file_target]
if not use_nvshmem:
cmd += [temp_code]
Expand Down
1 change: 1 addition & 0 deletions python/tvm/s_tir/backend/adreno/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
# VerifyVTCMLimit must occur before LowerVtcmAlloc.
s_tir.transform.VerifyVTCMLimit(),
s_tir.transform.LowerVtcmAlloc(),
tirx.transform.BindParallelLoopsToThreads(),
tirx.transform.VerifyMemory(),
tirx.transform.AnnotateEntryFunc(),
]
Expand Down
1 change: 1 addition & 0 deletions python/tvm/s_tir/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
# VerifyVTCMLimit must occur before LowerVtcmAlloc.
s_tir.transform.VerifyVTCMLimit(),
s_tir.transform.LowerVtcmAlloc(),
tirx.transform.BindParallelLoopsToThreads(),
tirx.transform.VerifyMemory(),
tirx.transform.AnnotateEntryFunc(),
]
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/tirx/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,17 @@ def VerifyMemory():
return _ffi_api.VerifyMemory() # type: ignore


def BindParallelLoopsToThreads():
"""Convert T.parallel loops to block/thread bindings for GPU PrimFuncs.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.BindParallelLoopsToThreads() # type: ignore


@_ffi.register_object("s_tir.transform.HoistIfThenElseConfig")
class HoistIfThenElseConfig(_ir.Attrs):
"""Config for hoist if then else pass"""
Expand Down
2 changes: 1 addition & 1 deletion src/s_tir/transform/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt b
if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1)) &&
!no_unroll_loop_with_extent_one_ && for_node->annotations.empty()) {
// If the loop extent is 1, do not create the loop anymore
return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
return Substitute(body, {{for_node->loop_var, make_const(DataType::Int(32), 0)}});
} else {
TVM_FFI_ICHECK(for_node->kind != ForKind::kThreadBinding);
auto new_loop = ffi::make_object<ForNode>(*for_node);
Expand Down
1 change: 1 addition & 0 deletions src/target/opt/build_cuda_on.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,5 +108,6 @@ TVM_FFI_STATIC_INIT_BLOCK() {
refl::GlobalDef().def("target.build.cuda", BuildCUDA);
}
TVM_REGISTER_PASS_CONFIG_OPTION("cuda.kernels_output_dir", ffi::String);
TVM_REGISTER_PASS_CONFIG_OPTION("cuda.nvcc_allow_unsupported_compiler", Bool);
} // namespace codegen
} // namespace tvm
4 changes: 2 additions & 2 deletions src/tirx/analysis/verify_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
void Run() {
if (!IsGPUDevice(dev_type_)) return;
StmtExprVisitor::VisitStmt(func_->body);
}
}

/// Verification result
std::vector<ffi::String> Errors() const { return errs_; }
Expand Down Expand Up @@ -150,7 +150,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
/// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device.
static bool IsGPUDevice(int dev_type) {
return kDLCUDA == dev_type || kDLOpenCL == dev_type || kDLVulkan == dev_type ||
kDLMetal == dev_type || kDLROCM == dev_type;
kDLMetal == dev_type || kDLROCM == dev_type || kDLWebGPU == dev_type;
}

private:
Expand Down
146 changes: 146 additions & 0 deletions src/tirx/transform/bind_parallel_loops_to_threads.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file bind_parallel_loops_to_threads.cc
* \brief Convert ForKind::kParallel loops to GPU thread bindings.
*/

#include <tvm/ffi/reflection/registry.h>
#include <tvm/s_tir/stmt.h>
#include <tvm/target/target.h>
#include <tvm/tirx/op.h>
#include <tvm/tirx/stmt.h>
#include <tvm/tirx/stmt_functor.h>
#include <tvm/tirx/transform.h>

namespace tvm {
namespace tirx {
namespace {

static bool IsGpuDeviceType(int dev_type) {
return dev_type == kDLCUDA || dev_type == kDLROCM || dev_type == kDLOpenCL ||
dev_type == kDLVulkan || dev_type == kDLMetal || dev_type == kDLWebGPU;
}
Comment thread
zhils marked this conversation as resolved.

class ParallelLoopToThreadBindingMutator : public StmtExprMutator {
public:
explicit ParallelLoopToThreadBindingMutator(int64_t max_threads_per_block)
: max_threads_per_block_(max_threads_per_block) {}

private:
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == tirx::attr::thread_extent || op->attr_key == s_tir::attr::virtual_thread) {
bool prev = in_thread_env_;
in_thread_env_ = true;
Stmt ret = StmtExprMutator::VisitStmt_(op);
in_thread_env_ = prev;
return ret;
}
return StmtExprMutator::VisitStmt_(op);
}

Stmt TransformParallelFor(const ForNode* for_node) {
if (in_thread_env_) {
return ffi::GetRef<Stmt>(for_node);
}

DataType dtype = for_node->loop_var.dtype();
PrimExpr min = cast(dtype, for_node->min);
PrimExpr extent = cast(dtype, for_node->extent);
PrimExpr max_threads = IntImm(dtype, max_threads_per_block_);
PrimExpr num_blocks = ceildiv(extent, max_threads);

Var tx_var("threadIdx.x", dtype);
Var bx_var("blockIdx.x", dtype);
IterVar tx_iter(Range::FromMinExtent(IntImm(dtype, 0), max_threads), tx_var,
IterVarType::kThreadIndex, "threadIdx.x");
IterVar bx_iter(Range::FromMinExtent(IntImm(dtype, 0), num_blocks), bx_var,
IterVarType::kThreadIndex, "blockIdx.x");

PrimExpr global_idx = cast(dtype, bx_var * max_threads + tx_var);
PrimExpr mapped_idx = cast(dtype, min + global_idx);
Stmt mapped_body = Substitute(for_node->body, {{Var(for_node->loop_var), mapped_idx}});
mapped_body = IfThenElse(global_idx < extent, mapped_body, Evaluate(IntImm(DataType::Int(32), 0)));

Stmt body_with_tx = AttrStmt(tx_iter, tirx::attr::thread_extent, max_threads, mapped_body);
Stmt body_with_bx = AttrStmt(bx_iter, tirx::attr::thread_extent, num_blocks, body_with_tx);
return body_with_bx;
}

Stmt VisitStmt_(const ForNode* op) final {
if (op->kind == ForKind::kThreadBinding) {
bool prev = in_thread_env_;
in_thread_env_ = true;
Stmt ret = StmtExprMutator::VisitStmt_(op);
in_thread_env_ = prev;
return ret;
}
if (op->kind != ForKind::kParallel) {
return StmtExprMutator::VisitStmt_(op);
}
if (in_parallel_loop_) {
return StmtExprMutator::VisitStmt_(op);
}
bool prev_in_parallel = in_parallel_loop_;
in_parallel_loop_ = true;
For updated = Downcast<For>(StmtExprMutator::VisitStmt_(op));
in_parallel_loop_ = prev_in_parallel;
return TransformParallelFor(updated.get());
}

int64_t max_threads_per_block_;
bool in_thread_env_{false};
bool in_parallel_loop_{false};
};

} // namespace

namespace transform {

Pass BindParallelLoopsToThreads() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto opt_target = f->GetAttr<Target>(tvm::attr::kTarget);
Target target = opt_target.value_or(Target::Current(/*allow_none=*/true));
if (!target.defined() || !IsGpuDeviceType(target->GetTargetDeviceType())) {
return f;
}

int64_t max_threads_per_block = 1024;
if (auto opt_max_threads = target->GetAttr<Integer>("max_num_threads")) {
max_threads_per_block = opt_max_threads.value()->value;
}

PrimFuncNode* n = f.CopyOnWrite();
n->body = ParallelLoopToThreadBindingMutator(max_threads_per_block)(n->body);
return f;
};

return CreatePrimFuncPass(pass_func, 0, "tirx.BindParallelLoopsToThreads", {});
}

TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tirx.transform.BindParallelLoopsToThreads", BindParallelLoopsToThreads);
}

} // namespace transform
} // namespace tirx
} // namespace tvm