Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions core/conversion/converters/impl/internal_ops.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include <limits>
#include <vector>
#include "core/conversion/converters/converters.h"
#include "core/util/prelude.h"
#include "torch/torch.h"
Expand All @@ -18,20 +20,18 @@ auto linear_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pat
auto in = args[0].ITensorOrFreeze(ctx);
auto out = in;
if (in->getType() == nvinfer1::DataType::kBOOL) {
auto not_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kNOT);
TORCHTRT_CHECK(not_layer, "Unable to create not layer for attn_bias_from_attn_mask");
not_layer->setName((util::node_info(n) + "_not").c_str());
auto neg_inf = torch::tensor(-std::numeric_limits<float>::infinity());
auto neg_inf_itensor = tensor_to_const(ctx, neg_inf);
auto prod_layer = add_elementwise(
std::vector<int64_t> singleton_dims(in->getDimensions().nbDims, 1);
auto options = torch::TensorOptions().dtype(torch::kFloat32);
auto zero = tensor_to_const(
ctx, torch::full(singleton_dims, 0.0f, options), util::node_info(n) + "_zero");
auto neg_inf = tensor_to_const(
ctx,
nvinfer1::ElementWiseOperation::kPROD,
not_layer->getOutput(0),
neg_inf_itensor,
util::node_info(n) + "_mul");
auto add_layer = add_elementwise(
ctx, nvinfer1::ElementWiseOperation::kSUM, prod_layer->getOutput(0), in, util::node_info(n) + "_add");
out = add_layer->getOutput(0);
torch::full(singleton_dims, -std::numeric_limits<float>::infinity(), options),
util::node_info(n) + "_neg_inf");
auto select_layer = ctx->net->addSelect(*in, *zero, *neg_inf);
TORCHTRT_CHECK(select_layer, "Unable to create select layer for attn_bias_from_attn_mask");
select_layer->setName(util::node_info(n).c_str());
out = select_layer->getOutput(0);
}
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out);
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,39 @@ TEST(Converters, ATenScaledDotProductAttnMaskFloatConvertsCorrectly) {
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0]));
}

TEST(Converters, ATenScaledDotProductAttnMaskBoolDoesNotProduceNaN) {
const auto graph = R"IR(
graph(%query : Tensor, %key : Tensor, %value : Tensor, %attn_mask : Tensor):
%0 : float = prim::Constant[value=0.]()
%false : bool = prim::Constant[value=0]()
%scale : NoneType = prim::Constant()
%enable_gqa : bool = prim::Constant[value=0]()
%3 : Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %0, %false, %scale, %enable_gqa)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, &*g);

auto query = at::arange(16, {at::kCUDA}).to(at::kFloat).reshape({1, 1, 4, 4}) / 16.0;
auto key = at::arange(16, {at::kCUDA}).to(at::kFloat).reshape({1, 1, 4, 4}) / 13.0;
auto value = at::arange(16, {at::kCUDA}).to(at::kFloat).reshape({1, 1, 4, 4}) / 11.0;
auto attn_mask = at::tensor(
{1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1},
at::TensorOptions().dtype(at::kBool).device(at::kCUDA))
.reshape({1, 1, 4, 4});
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {query, key, value, attn_mask});

torch_tensorrt::core::lowering::passes::UnpackScaledDotProductAttention(g);

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {query, key, value, attn_mask});

ASSERT_FALSE(torch::isnan(jit_results[0]).any().item<bool>());
ASSERT_FALSE(torch::isnan(trt_results[0]).any().item<bool>());
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0]));
}

TEST(Converters, ATenScaledDotProductAttnMaskIntConvertsCorrectly) {
const auto graph = R"IR(
graph(%query : Tensor, %key : Tensor, %value : Tensor, %attn_mask : Tensor):
Expand Down
Loading