diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index bc3bf14bf14..4e053f5796c 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -197,6 +197,11 @@ def is_node_supported( return r def _is_node_supported(self, node: torch.fx.Node) -> bool: # noqa: C901 + # Check if tensor node dtype is supported by vulkan + if utils.is_tensor_node(node) and not utils.io_dtypes_are_supported(node): + self.log_skip(node, "dtype not supported") + return False + if node.op == "call_function": # Apply nn module allowlist and blocklist if self.nn_module_allowlist is not None: diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index 43ea6c7ce30..128015af4f2 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -12,14 +12,13 @@ from typing import cast, List, Optional, Union import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema - import torch - from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( VkMemoryLayout, VkStorageType, ) from executorch.backends.vulkan.utils import ( + get_vk_datatype, is_constant, is_get_attr_node, is_mutable_buffer_node, @@ -29,7 +28,6 @@ ) from executorch.exir._serialize._named_data_store import NamedDataStore from executorch.exir.backend.utils import DelegateMappingBuilder - from executorch.exir.tensor import TensorSpec from torch._export.utils import get_buffer, get_param, is_buffer, is_param from torch.export import ExportedProgram @@ -71,27 +69,6 @@ def __init__( # For logging self.seen_ops = set() - @staticmethod - def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType: - if torch_dtype == torch.bool: - return vk_graph_schema.VkDataType.BOOL - elif torch_dtype == torch.uint8: - return vk_graph_schema.VkDataType.UINT8 - elif torch_dtype == torch.int8: - return vk_graph_schema.VkDataType.INT8 - elif torch_dtype == torch.int32: - return vk_graph_schema.VkDataType.INT32 - elif torch_dtype == torch.int64: - return vk_graph_schema.VkDataType.INT64 - elif torch_dtype == torch.float16: - return vk_graph_schema.VkDataType.FLOAT16 - elif torch_dtype == torch.float32: - return vk_graph_schema.VkDataType.FLOAT32 - elif torch_dtype == torch.float64: - return vk_graph_schema.VkDataType.FLOAT64 - else: - raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})") - def get_constant(self, node: Node) -> Optional[torch.Tensor]: """ Returns the constant associated with the given node in the exported program. @@ -275,8 +252,8 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int: effective_dtype if constant_id >= 0 else self.get_staging_dtype(spec.dtype) ) - datatype = self.get_vk_datatype(effective_dtype) - staging_datatype = self.get_vk_datatype(staging_dtype) + datatype = get_vk_datatype(effective_dtype) + staging_datatype = get_vk_datatype(staging_dtype) new_id = len(self.values) self.values.append( diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index 2ca2ddf19b7..e19c3ac1891 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -5,10 +5,11 @@ # LICENSE file in the root directory of this source tree. import operator -from typing import Any, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union import torch from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( + VkDataType, VkMemoryLayout, VkStorageType, ) @@ -42,6 +43,17 @@ "quantize_affine.default", } +_VULKAN_DTYPES: Dict[torch.dtype, VkDataType] = { + torch.bool: VkDataType.BOOL, + torch.uint8: VkDataType.UINT8, + torch.int8: VkDataType.INT8, + torch.int32: VkDataType.INT32, + torch.int64: VkDataType.INT64, + torch.float16: VkDataType.FLOAT16, + torch.float32: VkDataType.FLOAT32, + torch.float64: VkDataType.FLOAT64, +} + ## ## Node type determination ## @@ -237,6 +249,73 @@ def num_tensors_in_node(node: torch.fx.Node) -> int: return 0 +def get_vk_datatype(torch_dtype: torch.dtype) -> VkDataType: + """ + Returns Vulkan dtype corresponding to torch dtype + """ + if torch_dtype not in _VULKAN_DTYPES: + raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})") + + return _VULKAN_DTYPES[torch_dtype] + + +def output_dtypes_are_supported(node: torch.fx.Node) -> bool: + """ + Returns true if the output of the given tensor node has dtype that + is supported by the Vulkan backend. + """ + if not is_tensor_node(node): + return True + + # The val metadata must exist after previous check + node_val = node.meta.get("val", None) + assert node_val is not None + + # Get all the tensor dtypes in the node + tensor_dtypes = [] + if isinstance(node_val, FakeTensor): + tensor_dtypes = [node_val.dtype] + elif isinstance(node_val, list) or isinstance(node_val, tuple): + tensor_dtypes = [x.dtype for x in node_val] + + # Verify that all the tensor_dtypes are in vk_torch_dtypes + return all(dtype in _VULKAN_DTYPES for dtype in tensor_dtypes) + + +def input_dtypes_are_supported(node: torch.fx.Node) -> bool: + """ + Returns true if all the inputs to the given tensor node have dtype that + is supported by the Vulkan backend. + """ + if not is_tensor_node(node): + return True + + # Iterate over all the args of the node + for arg_node in node.args: + # The arg could be a single node, or a list (e.g., first arg of cat) + if isinstance(arg_node, torch.fx.Node): + if not output_dtypes_are_supported(arg_node): + return False + elif isinstance(arg_node, (list, tuple)): + if not all(output_dtypes_are_supported(x) for x in arg_node): + return False + + return True + + +def io_dtypes_are_supported(node: torch.fx.Node) -> bool: + """ + Returns true if all the inputs and outputs of the given tensor node have + dtype that is supported by the Vulkan backend. + """ + if not output_dtypes_are_supported(node): + return False + if not input_dtypes_are_supported(node): + return False + + return True + + def tensor_node_is_bool(node: torch.fx.Node) -> bool: """ Returns true if a given node contains a tensor with bool dtype