Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/webgpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ set(WEBGPU_SRCS
runtime/ops/sdpa/Sdpa.cpp
runtime/ops/select_as_symint/SelectAsSymint.cpp
runtime/ops/quantized_linear/QuantizedLinear.cpp
runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp
)

add_library(webgpu_backend ${WEBGPU_SRCS})
Expand Down
6 changes: 4 additions & 2 deletions backends/webgpu/runtime/WebGPUBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,13 @@ Error WebGPUBackend::execute(
const size_t num_outputs = graph->output_ids().size();

// Copy inputs from EValue tensors to GPU buffers
std::vector<std::pair<const void*, size_t>> inputs;
std::vector<InputData> inputs;
inputs.reserve(num_inputs);
for (size_t i = 0; i < num_inputs; i++) {
const auto& tensor = args[i]->toTensor();
inputs.emplace_back(tensor.const_data_ptr(), tensor.nbytes());
const bool host_is_int64 =
tensor.scalar_type() == executorch::aten::ScalarType::Long;
inputs.push_back({tensor.const_data_ptr(), tensor.nbytes(), host_is_int64});
}
graph->copy_inputs(inputs);

Expand Down
57 changes: 48 additions & 9 deletions backends/webgpu/runtime/WebGPUGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,19 @@ size_t vk_datatype_size(vkgraph::VkDataType dtype) {
}
}

bool vk_datatype_is_int(vkgraph::VkDataType dtype) {
switch (dtype) {
case vkgraph::VkDataType::BOOL:
case vkgraph::VkDataType::UINT8:
case vkgraph::VkDataType::INT8:
case vkgraph::VkDataType::INT32:
case vkgraph::VkDataType::INT64:
return true;
default:
return false;
}
}

} // namespace

WebGPUGraph::WebGPUGraph() = default;
Expand All @@ -61,7 +74,7 @@ WGPUBuffer WebGPUGraph::create_scratch_buffer(size_t nbytes) {
}

void WebGPUGraph::update_symints_from_inputs(
const std::vector<std::pair<const void*, size_t>>& inputs) {
const std::vector<InputData>& inputs) {
for (const auto& src : symint_sources_) {
int pos = -1;
for (size_t i = 0; i < input_ids_.size(); i++) {
Expand Down Expand Up @@ -100,8 +113,8 @@ void WebGPUGraph::update_symints_from_inputs(
// Reads the [0,..,index,..,0] element; symint sources are scalar-ish.
const int64_t offset = static_cast<int64_t>(index) * stride;
// elem_size back-derived from build-time numel (sources are static-shaped).
const void* host = inputs[pos].first;
const size_t elem_size = inputs[pos].second / static_cast<size_t>(numel);
const void* host = inputs[pos].data;
const size_t elem_size = inputs[pos].nbytes / static_cast<size_t>(numel);
int32_t val;
if (elem_size == sizeof(int64_t)) {
val = static_cast<int32_t>(static_cast<const int64_t*>(host)[offset]);
Expand Down Expand Up @@ -248,7 +261,9 @@ void WebGPUGraph::build(
numel *= dims->Get(j);
}
}
tensor.nbytes = numel * vk_datatype_size(vk_tensor->datatype());
tensor.elem_size = vk_datatype_size(vk_tensor->datatype());
tensor.is_int = vk_datatype_is_int(vk_tensor->datatype());
tensor.nbytes = numel * tensor.elem_size;

int constant_id = vk_tensor->constant_id();
int mem_obj_id = vk_tensor->mem_obj_id();
Expand Down Expand Up @@ -484,16 +499,40 @@ WGPUBindGroupLayout WebGPUGraph::get_or_create_bgl(
return bgl;
}

void WebGPUGraph::copy_inputs(
const std::vector<std::pair<const void*, size_t>>& inputs) {
void WebGPUGraph::copy_inputs(const std::vector<InputData>& inputs) {
for (size_t i = 0; i < inputs.size() && i < input_ids_.size(); i++) {
if (inputs[i].second == 0) {
const InputData& in = inputs[i];
if (in.nbytes == 0) {
continue;
}
int tid = input_ids_[i];
const auto& tensor = tensors_[tid];
wgpuQueueWriteBuffer(
queue_, tensor.buffer, 0, inputs[i].first, inputs[i].second);

// Fast path: host and GPU element types match byte-for-byte.
if (in.nbytes == tensor.nbytes) {
wgpuQueueWriteBuffer(queue_, tensor.buffer, 0, in.data, tensor.nbytes);
continue;
}

// Narrow int64 host indices into the int32 buffer (mirrors Vulkan).
const bool buffer_is_int32 = tensor.is_int && tensor.elem_size == 4;
if (in.host_is_int64 && buffer_is_int32 && in.nbytes == tensor.nbytes * 2) {
const size_t numel = tensor.nbytes / 4;
const int64_t* src = static_cast<const int64_t*>(in.data);
std::vector<int32_t> narrowed(numel);
for (size_t e = 0; e < numel; e++) {
narrowed[e] = static_cast<int32_t>(src[e]);
}
wgpuQueueWriteBuffer(
queue_, tensor.buffer, 0, narrowed.data(), tensor.nbytes);
continue;
}

throw std::runtime_error(
"WebGPU: unsupported input copy for input " + std::to_string(i) +
" (host " + std::to_string(in.nbytes) + " bytes" +
(in.host_is_int64 ? " int64" : "") + " vs buffer " +
std::to_string(tensor.nbytes) + " bytes)");
}
}

Expand Down
15 changes: 12 additions & 3 deletions backends/webgpu/runtime/WebGPUGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ struct WebGPUTensor {
WGPUBuffer buffer = nullptr;
std::vector<int64_t> dims;
size_t nbytes = 0;
// Serialized (GPU-side) element type, used to narrow wider host inputs.
size_t elem_size = 0;
bool is_int = false;
};

// Host-side view of one graph input, passed to copy_inputs.
struct InputData {
const void* data = nullptr;
size_t nbytes = 0;
bool host_is_int64 = false;
};

struct WebGPUDispatch {
Expand Down Expand Up @@ -75,7 +85,7 @@ class WebGPUGraph {
const executorch::runtime::NamedDataMap* named_data_map = nullptr);

// Copy input tensor data from host pointers into GPU buffers.
void copy_inputs(const std::vector<std::pair<const void*, size_t>>& inputs);
void copy_inputs(const std::vector<InputData>& inputs);

// Execute all recorded dispatches.
void execute();
Expand Down Expand Up @@ -138,8 +148,7 @@ class WebGPUGraph {
}

// Execute-time select_as_symint read; mirrors Vulkan select_as_symint_impl.
void update_symints_from_inputs(
const std::vector<std::pair<const void*, size_t>>& inputs);
void update_symints_from_inputs(const std::vector<InputData>& inputs);

// Per-SymInt resize hook; mirrors Vulkan DynamicDispatchNode::trigger_resize.
void add_resize_hook(int symint_id, std::function<void(WebGPUGraph&)> fn) {
Expand Down
Loading
Loading