backends/vulkan/partitioner/vulkan_partitioner.py (+5, -0)

@@ -197,6 +197,11 @@ def is_node_supported(
         return r

     def _is_node_supported(self, node: torch.fx.Node) -> bool:  # noqa: C901
+        # Check if tensor node dtype is supported by vulkan
+        if utils.is_tensor_node(node) and not utils.io_dtypes_are_supported(node):
+            self.log_skip(node, "dtype not supported")
+            return False
+
         if node.op == "call_function":
             # Apply nn module allowlist and blocklist
             if self.nn_module_allowlist is not None:
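
For context, a minimal sketch of the behavior this guard enables (the module below is hypothetical, not part of the PR): torch.fft.fft on a float32 input yields a complex64 tensor, which is outside the Vulkan backend's supported dtypes, so the partitioner should now log a skip for that node instead of failing later during serialization.

    import torch

    class FFTModule(torch.nn.Module):
        # Hypothetical repro: fft on float32 produces a complex64 output,
        # a dtype the Vulkan backend does not support.
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.fft.fft(x)

    # When partitioned for Vulkan, the fft node is expected to be skipped
    # with the log message "dtype not supported" rather than delegated.
    ep = torch.export.export(FFTModule(), (torch.randn(8),))
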
backends/vulkan/serialization/vulkan_graph_builder.py (+3, -26)

@@ -12,14 +12,13 @@
 from typing import cast, List, Optional, Union

 import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema
-
 import torch
-
 from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
     VkMemoryLayout,
     VkStorageType,
 )
 from executorch.backends.vulkan.utils import (
+    get_vk_datatype,
     is_constant,
     is_get_attr_node,
     is_mutable_buffer_node,
@@ -29,7 +28,6 @@
 )
 from executorch.exir._serialize._named_data_store import NamedDataStore
 from executorch.exir.backend.utils import DelegateMappingBuilder
-
 from executorch.exir.tensor import TensorSpec
 from torch._export.utils import get_buffer, get_param, is_buffer, is_param
 from torch.export import ExportedProgram
@@ -71,27 +69,6 @@ def __init__(
         # For logging
         self.seen_ops = set()

-    @staticmethod
-    def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
-        if torch_dtype == torch.bool:
-            return vk_graph_schema.VkDataType.BOOL
-        elif torch_dtype == torch.uint8:
-            return vk_graph_schema.VkDataType.UINT8
-        elif torch_dtype == torch.int8:
-            return vk_graph_schema.VkDataType.INT8
-        elif torch_dtype == torch.int32:
-            return vk_graph_schema.VkDataType.INT32
-        elif torch_dtype == torch.int64:
-            return vk_graph_schema.VkDataType.INT64
-        elif torch_dtype == torch.float16:
-            return vk_graph_schema.VkDataType.FLOAT16
-        elif torch_dtype == torch.float32:
-            return vk_graph_schema.VkDataType.FLOAT32
-        elif torch_dtype == torch.float64:
-            return vk_graph_schema.VkDataType.FLOAT64
-        else:
-            raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")
-
     def get_constant(self, node: Node) -> Optional[torch.Tensor]:
         """
         Returns the constant associated with the given node in the exported program.
@@ -275,8 +252,8 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
             effective_dtype if constant_id >= 0 else self.get_staging_dtype(spec.dtype)
         )

-        datatype = self.get_vk_datatype(effective_dtype)
-        staging_datatype = self.get_vk_datatype(staging_dtype)
+        datatype = get_vk_datatype(effective_dtype)
+        staging_datatype = get_vk_datatype(staging_dtype)

         new_id = len(self.values)
         self.values.append(
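
Behavior is unchanged by the move; the dtype mapping is simply hoisted out of the graph builder so the partitioner can share it. A minimal usage sketch of the relocated helper:

    import torch
    from executorch.backends.vulkan.utils import get_vk_datatype

    # Same mapping as the removed static method, now dictionary-backed
    # and importable from utils.
    dt = get_vk_datatype(torch.float16)  # VkDataType.FLOAT16
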
backends/vulkan/utils.py (+80, -1)

@@ -5,10 +5,11 @@
 # LICENSE file in the root directory of this source tree.

 import operator
-from typing import Any, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Union

 import torch
 from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
+    VkDataType,
     VkMemoryLayout,
     VkStorageType,
 )
@@ -42,6 +43,17 @@
     "quantize_affine.default",
 }

+_VULKAN_DTYPES: Dict[torch.dtype, VkDataType] = {
+    torch.bool: VkDataType.BOOL,
+    torch.uint8: VkDataType.UINT8,
+    torch.int8: VkDataType.INT8,
+    torch.int32: VkDataType.INT32,
+    torch.int64: VkDataType.INT64,
+    torch.float16: VkDataType.FLOAT16,
+    torch.float32: VkDataType.FLOAT32,
+    torch.float64: VkDataType.FLOAT64,
+}
+
 ##
 ## Node type determination
 ##
@@ -237,6 +249,73 @@ def num_tensors_in_node(node: torch.fx.Node) -> int:
     return 0


+def get_vk_datatype(torch_dtype: torch.dtype) -> VkDataType:
+    """
+    Returns the Vulkan dtype corresponding to the given torch dtype.
+    """
+    if torch_dtype not in _VULKAN_DTYPES:
+        raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")
+
+    return _VULKAN_DTYPES[torch_dtype]
+
+
+def output_dtypes_are_supported(node: torch.fx.Node) -> bool:
+    """
+    Returns true if the output of the given tensor node has a dtype that is
+    supported by the Vulkan backend.
+    """
+    if not is_tensor_node(node):
+        return True
+
+    # The val metadata must exist after the previous check
+    node_val = node.meta.get("val", None)
+    assert node_val is not None
+
+    # Get all the tensor dtypes in the node
+    tensor_dtypes = []
+    if isinstance(node_val, FakeTensor):
+        tensor_dtypes = [node_val.dtype]
+    elif isinstance(node_val, (list, tuple)):
+        tensor_dtypes = [x.dtype for x in node_val]
+
+    # Verify that all the tensor dtypes are in _VULKAN_DTYPES
+    return all(dtype in _VULKAN_DTYPES for dtype in tensor_dtypes)
+
+
+def input_dtypes_are_supported(node: torch.fx.Node) -> bool:
+    """
+    Returns true if all the inputs to the given tensor node have dtypes that
+    are supported by the Vulkan backend.
+    """
+    if not is_tensor_node(node):
+        return True
+
+    # Iterate over all the args of the node
+    for arg_node in node.args:
+        # The arg could be a single node, or a list (e.g., first arg of cat)
+        if isinstance(arg_node, torch.fx.Node):
+            if not output_dtypes_are_supported(arg_node):
+                return False
+        elif isinstance(arg_node, (list, tuple)):
+            if not all(output_dtypes_are_supported(x) for x in arg_node):
+                return False
+
+    return True
+
+
+def io_dtypes_are_supported(node: torch.fx.Node) -> bool:
+    """
+    Returns true if all the inputs and outputs of the given tensor node have
+    dtypes that are supported by the Vulkan backend.
+    """
+    if not output_dtypes_are_supported(node):
+        return False
+    if not input_dtypes_are_supported(node):
+        return False
+
+    return True
+
+
 def tensor_node_is_bool(node: torch.fx.Node) -> bool:
     """
     Returns true if a given node contains a tensor with bool dtype
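
A minimal end-to-end sketch of how the new helpers compose (the example module is hypothetical; it assumes an ExecuTorch checkout where backends.vulkan.utils is importable). The loop mirrors the partitioner's new guard: a node fails the check if any input or output tensor dtype falls outside _VULKAN_DTYPES.

    import torch
    from executorch.backends.vulkan import utils

    class CastToComplex(torch.nn.Module):
        # Hypothetical example: the final cast produces complex64, which is
        # outside _VULKAN_DTYPES, so that node should fail the dtype check.
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return (x * 2.0 + 1.0).to(torch.complex64)

    ep = torch.export.export(CastToComplex(), (torch.randn(4),))

    # Same check the partitioner now performs before claiming a node.
    for node in ep.graph_module.graph.nodes:
        if utils.is_tensor_node(node) and not utils.io_dtypes_are_supported(node):
            print(f"skip {node}: dtype not supported")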