From 20d4d6ac5ebba6f46145171a0026f0488c6ab6e0 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Tue, 13 Jan 2026 09:52:05 -0500 Subject: [PATCH 1/2] export/lower parakeet to Metal backend handle grouped convolution implement aoti_torch_mps_bmm_out linear decomposition enable non-zero storage offset in aoti_torch__reinterpret_tensor descriptive error for unimplemented shim functions implement aoti_torch_new_tensor_handle fix lint --- backends/aoti/common_shims.cpp | 10 +- backends/apple/metal/metal_backend.py | 2 +- backends/apple/metal/runtime/shims/et_metal.h | 1 + .../apple/metal/runtime/shims/et_metal.mm | 12 + .../apple/metal/runtime/shims/et_metal_ops.h | 10 + .../apple/metal/runtime/shims/et_metal_ops.mm | 315 +++++++++++++++++- backends/apple/metal/runtime/shims/memory.cpp | 116 ++++++- examples/models/parakeet/README.md | 35 +- .../models/parakeet/export_parakeet_tdt.py | 148 ++++++-- 9 files changed, 593 insertions(+), 56 deletions(-) diff --git a/backends/aoti/common_shims.cpp b/backends/aoti/common_shims.cpp index abfde86db6d..7c88e4cfb5b 100644 --- a/backends/aoti/common_shims.cpp +++ b/backends/aoti/common_shims.cpp @@ -218,7 +218,7 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size) { (void)tensor; (void)ret_size; - throw std::runtime_error("Not implemented"); + throw std::runtime_error("Not implemented: aoti_torch_get_storage_size"); return Error::Internal; } @@ -226,7 +226,8 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch_clone_preserve_strides(Tensor* self, Tensor** ret_new_tensor) { (void)self; (void)ret_new_tensor; - throw std::runtime_error("Not implemented"); + throw std::runtime_error( + "Not implemented: aoti_torch_clone_preserve_strides"); return Error::Internal; } @@ -234,7 +235,7 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch_clone(Tensor* self, Tensor** ret_new_tensor) { (void)self; (void)ret_new_tensor; - throw std::runtime_error("Not implemented"); + throw std::runtime_error("Not implemented: aoti_torch_clone"); return Error::Internal; } @@ -257,7 +258,8 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob( (void)device_type; (void)device_index; (void)ret_new_tensor; - throw std::runtime_error("Not implemented"); + throw std::runtime_error( + "Not implemented: aoti_torch_create_tensor_from_blob"); return Error::Internal; } diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py index 1d86cfb8447..fde0410cca3 100644 --- a/backends/apple/metal/metal_backend.py +++ b/backends/apple/metal/metal_backend.py @@ -31,7 +31,7 @@ def get_device_name(cls) -> str: @classmethod def get_supported_fallback_kernels(cls) -> Dict[str, Any]: return { - "aoti_torch_mps_addmm_out": None, + "aoti_torch_mps_bmm_out": None, "aoti_torch_mps_convolution": None, "aoti_torch_mps_mm_out": None, "at::_ops::_scaled_dot_product_attention_math_for_mps::call": None, diff --git a/backends/apple/metal/runtime/shims/et_metal.h b/backends/apple/metal/runtime/shims/et_metal.h index 1c61499b242..e4d71fed72e 100644 --- a/backends/apple/metal/runtime/shims/et_metal.h +++ b/backends/apple/metal/runtime/shims/et_metal.h @@ -379,6 +379,7 @@ int metal_copy_memory( bool src_is_device, bool dst_is_device); void metal_cleanup_resources(); +void metal_buffer_nocopy(void* ptr, size_t nbytes, bool map_ptr_to_buffer); // Helper functions to access Metal objects MTLDevice_t get_metal_device(); diff --git a/backends/apple/metal/runtime/shims/et_metal.mm b/backends/apple/metal/runtime/shims/et_metal.mm index 
f7d37c152ce..4f4464a534c 100644 --- a/backends/apple/metal/runtime/shims/et_metal.mm +++ b/backends/apple/metal/runtime/shims/et_metal.mm @@ -113,6 +113,18 @@ void metal_cleanup_resources() { } } +void metal_buffer_nocopy(void* ptr, size_t nbytes, bool map_ptr_to_buffer) { + id device = get_metal_device(); + id subBuffer = [device newBufferWithBytesNoCopy:ptr + length:nbytes + options:MTLResourceCPUCacheModeWriteCombined | MTLResourceStorageModeShared + deallocator:nil]; + + if (map_ptr_to_buffer) { + ptr_to_mtl_buffer[ptr] = subBuffer; // Map contents to buffer + } +} + bool metal_is_device_pointer(void* ptr) { return ptr_to_mtl_buffer.find(ptr) != ptr_to_mtl_buffer.end(); } diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.h b/backends/apple/metal/runtime/shims/et_metal_ops.h index 78bdb419ea4..fcc6dfc03da 100644 --- a/backends/apple/metal/runtime/shims/et_metal_ops.h +++ b/backends/apple/metal/runtime/shims/et_metal_ops.h @@ -27,6 +27,16 @@ AOTITorchError aoti_torch_mps_mm_out( AOTITensorHandle self, AOTITensorHandle mat2); +/** + * ExecutorTorch implementation of aoti_torch_mps_bmm_out. + * Performs batched matrix multiplication: out = self @ mat2 + * All tensors must be 3-D with matching batch dimensions. + */ +AOTITorchError aoti_torch_mps_bmm_out( + AOTITensorHandle out, + AOTITensorHandle self, + AOTITensorHandle mat2); + /** * ExecutorTorch implementation of aoti_torch_mps_convolution. * Performs 2D convolution operation - matches PyTorch AOTI signature diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.mm b/backends/apple/metal/runtime/shims/et_metal_ops.mm index da54dafb334..5b413728de5 100644 --- a/backends/apple/metal/runtime/shims/et_metal_ops.mm +++ b/backends/apple/metal/runtime/shims/et_metal_ops.mm @@ -626,6 +626,316 @@ AOTITorchError aoti_torch_mps_mm_out( } } +AOTITorchError aoti_torch_mps_bmm_out( + AOTITensorHandle out, + AOTITensorHandle self, + AOTITensorHandle mat2) { + + // Validate non-null handles + if (!out || !self || !mat2) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: null tensor handles"); + return Error::InvalidArgument; + } + + @autoreleasepool { + try { + // Convert AOTITensorHandle to ExecutorTorch tensors + auto out_tensor = reinterpret_cast(out); + auto self_tensor = reinterpret_cast(self); + auto mat2_tensor = reinterpret_cast(mat2); + + // Validate tensor dimensions - bmm requires 3-D tensors + if (self_tensor->dim() != 3 || mat2_tensor->dim() != 3 || out_tensor->dim() != 3) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: tensors must be 3-D. " + "Got self.dim=%zd (shape=[%d,%d,%d]), " + "mat2.dim=%zd (shape=[%d,%d,%d]), " + "out.dim=%zd (shape=[%d,%d,%d])", + self_tensor->dim(), + self_tensor->dim() > 0 ? (int)self_tensor->sizes()[0] : 0, + self_tensor->dim() > 1 ? (int)self_tensor->sizes()[1] : 0, + self_tensor->dim() > 2 ? (int)self_tensor->sizes()[2] : 0, + mat2_tensor->dim(), + mat2_tensor->dim() > 0 ? (int)mat2_tensor->sizes()[0] : 0, + mat2_tensor->dim() > 1 ? (int)mat2_tensor->sizes()[1] : 0, + mat2_tensor->dim() > 2 ? (int)mat2_tensor->sizes()[2] : 0, + out_tensor->dim(), + out_tensor->dim() > 0 ? (int)out_tensor->sizes()[0] : 0, + out_tensor->dim() > 1 ? (int)out_tensor->sizes()[1] : 0, + out_tensor->dim() > 2 ? 
(int)out_tensor->sizes()[2] : 0); + return Error::InvalidArgument; + } + + int64_t B = self_tensor->sizes()[0]; // batch size + int64_t M = self_tensor->sizes()[1]; // rows of self + int64_t K = self_tensor->sizes()[2]; // cols of self / rows of mat2 + int64_t N = mat2_tensor->sizes()[2]; // cols of mat2 + + // Validate shape constraints + // self: [B, M, K], mat2: [B, K, N], out: [B, M, N] + if (mat2_tensor->sizes()[0] != B) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: batch size mismatch. " + "Expected mat2[0]=%d to match self[0]=%lld. " + "self.shape=[%lld,%lld,%lld], mat2.shape=[%d,%d,%d]", + (int)mat2_tensor->sizes()[0], (long long)B, + (long long)B, (long long)M, (long long)K, + (int)mat2_tensor->sizes()[0], (int)mat2_tensor->sizes()[1], (int)mat2_tensor->sizes()[2]); + return Error::InvalidArgument; + } + + if (mat2_tensor->sizes()[1] != K) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: incompatible matrix dimensions for bmm. " + "Expected mat2[1]=%d to match self[2]=%lld. " + "Cannot multiply [%lld,%lld,%lld] @ [%d,%d,%d]", + (int)mat2_tensor->sizes()[1], (long long)K, + (long long)B, (long long)M, (long long)K, + (int)mat2_tensor->sizes()[0], (int)mat2_tensor->sizes()[1], (int)mat2_tensor->sizes()[2]); + return Error::InvalidArgument; + } + + if (out_tensor->sizes()[0] != B || out_tensor->sizes()[1] != M || out_tensor->sizes()[2] != N) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: output shape mismatch. " + "Expected out.shape=[%lld,%lld,%lld], got [%d,%d,%d]", + (long long)B, (long long)M, (long long)N, + (int)out_tensor->sizes()[0], (int)out_tensor->sizes()[1], (int)out_tensor->sizes()[2]); + return Error::InvalidArgument; + } + + // Validate dtype consistency + int32_t self_dtype = static_cast(self_tensor->scalar_type()); + int32_t mat2_dtype = static_cast(mat2_tensor->scalar_type()); + int32_t out_dtype = static_cast(out_tensor->scalar_type()); + + if (self_dtype != mat2_dtype || self_dtype != out_dtype) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: dtype mismatch. " + "All tensors must have same dtype. Got self.dtype=%d, mat2.dtype=%d, out.dtype=%d", + self_dtype, mat2_dtype, out_dtype); + return Error::InvalidArgument; + } + + int32_t dtype = self_dtype; + + // Validate layout: BMM requires strictly contiguous 3D tensors + // For shape [B, M, K], contiguous strides MUST be [M*K, K, 1] + // + // Why strict contiguity is required: + // - MPSGraphTensorData initWithMTLBuffer:shape:dataType: interprets the MTLBuffer + // as containing dense row-major data for the given shape + // - Non-contiguous layouts (transposed, views with strides, etc.) have different + // memory layouts that don't match what MPS expects + // - This would result in SILENT WRONG RESULTS + // - This is an _out op: we must NOT create implicit copies + // - Policy: Reject non-contiguous inputs explicitly (transposed/view tensors unsupported) + // + // Limitation: This implementation does not explicitly check storage offset (no API available). + // Tensors with non-zero storage offsets are not explicitly rejected but may work if they + // happen to have contiguous strides. Users should ensure tensors are base tensors without offsets. + auto self_strides = self_tensor->strides(); + auto mat2_strides = mat2_tensor->strides(); + auto out_strides = out_tensor->strides(); + + // Check self tensor is contiguous [B, M, K] with strides [M*K, K, 1] + if (self_strides[2] != 1 || self_strides[1] != K || self_strides[0] != M * K) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: self tensor must be contiguous. 
" + "Only dense row-major layout supported; transposed/view tensors are unsupported. " + "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%d,%d,%d].", + (long long)(M * K), (long long)K, (long long)B, (long long)M, (long long)K, + self_strides[0], self_strides[1], self_strides[2]); + return Error::InvalidArgument; + } + + // Check mat2 tensor is contiguous [B, K, N] with strides [K*N, N, 1] + if (mat2_strides[2] != 1 || mat2_strides[1] != N || mat2_strides[0] != K * N) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: mat2 tensor must be contiguous. " + "Only dense row-major layout supported; transposed/view tensors are unsupported. " + "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%d,%d,%d].", + (long long)(K * N), (long long)N, (long long)B, (long long)K, (long long)N, + mat2_strides[0], mat2_strides[1], mat2_strides[2]); + return Error::InvalidArgument; + } + + // Check out tensor is contiguous [B, M, N] with strides [M*N, N, 1] + if (out_strides[2] != 1 || out_strides[1] != N || out_strides[0] != M * N) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: out tensor must be contiguous. " + "Only dense row-major layout supported; transposed/view tensors are unsupported. " + "Expected strides=[%lld,%lld,1] for shape=[%lld,%lld,%lld], got strides=[%d,%d,%d].", + (long long)(M * N), (long long)N, (long long)B, (long long)M, (long long)N, + out_strides[0], out_strides[1], out_strides[2]); + return Error::InvalidArgument; + } + + // Get Metal stream and device + ETMetalStream* stream = getCurrentMetalStream(); + if (!stream) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get current Metal stream"); + return Error::Internal; + } + + id device = get_metal_device(); + if (!device) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get Metal device"); + return Error::Internal; + } + (void)device; // Used for validation, consistent with other ops + + // Get Metal buffers for input and output tensors + id self_buffer = get_mtl_buffer(self_tensor, "aoti_torch_mps_bmm_out", "self"); + id mat2_buffer = get_mtl_buffer(mat2_tensor, "aoti_torch_mps_bmm_out", "mat2"); + id out_buffer = get_mtl_buffer(out_tensor, "aoti_torch_mps_bmm_out", "out"); + + // Validate buffers are non-null + if (!self_buffer || !mat2_buffer || !out_buffer) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: Failed to get Metal buffers. " + "self_buffer=%p, mat2_buffer=%p, out_buffer=%p", + self_buffer, mat2_buffer, out_buffer); + return Error::Internal; + } + + // End any existing kernel coalescing to ensure clean state + // (consistent with mm_out and conv pattern) + stream->endKernelCoalescing(); + + // Map dtype to MPS type and validate support + // Note: Only FLOAT32 and BFLOAT16 are supported in Metal backend (see utils.h) + // FLOAT16 is not in SupportedDTypes enum and is not supported + MPSDataType mps_dtype; + + if (dtype == static_cast(SupportedDTypes::FLOAT32)) { + mps_dtype = MPSDataTypeFloat32; + } else if (dtype == static_cast(SupportedDTypes::BFLOAT16)) { + mps_dtype = MPSDataTypeBFloat16; + } else { + ET_LOG(Error, "aoti_torch_mps_bmm_out: Unsupported data type: %d. 
" + "Supported types: FLOAT32 (%d), BFLOAT16 (%d)", + dtype, + static_cast(SupportedDTypes::FLOAT32), + static_cast(SupportedDTypes::BFLOAT16)); + return Error::InvalidArgument; + } + + // Define shapes for graph placeholders and tensor data + NSArray* selfShape = @[@(B), @(M), @(K)]; + NSArray* mat2Shape = @[@(B), @(K), @(N)]; + NSArray* outShape = @[@(B), @(M), @(N)]; + + // Create cache key for this batched matrix multiplication + // Cache key includes: op_name, shape params {B, M, K, N}, dtype, transpose_flag + // This allows reuse when same BMM shape/dtype is called repeatedly + GraphCacheKey cache_key; + cache_key.op_name = "bmm"; + cache_key.shape_params = {B, M, K, N}; + cache_key.dtype = dtype; + cache_key.transpose_flag = false; // BMM has no transpose handling + + // Check if we have a cached graph + MPSGraph* mpsGraph = nullptr; + MPSGraphTensor* outputTensor = nil; + MPSGraphTensor* selfPlaceholder = nil; + MPSGraphTensor* mat2Placeholder = nil; + + auto cache_it = graph_cache.find(cache_key); + if (cache_it != graph_cache.end()) { + // Cache hit - reuse compiled graph and tensor references + CachedGraph& cached = cache_it->second; + mpsGraph = cached.graph; + selfPlaceholder = cached.input1; + mat2Placeholder = cached.input2; + outputTensor = cached.output; + + cache_stats.hits++; + cache_stats.logStats(); + + } else { + // Cache miss - create and compile new graph + mpsGraph = [MPSGraph new]; + cache_stats.misses++; + cache_stats.logStats(); + + // Create 3D placeholders for batched matrices + // These represent the logical shapes for the batched matrix multiplication + selfPlaceholder = [mpsGraph placeholderWithShape:selfShape + dataType:mps_dtype + name:@"self"]; + mat2Placeholder = [mpsGraph placeholderWithShape:mat2Shape + dataType:mps_dtype + name:@"mat2"]; + + // MPSGraph matrixMultiplication handles batched case natively when given 3D tensors + // For 3D inputs [B,M,K] @ [B,K,N] -> [B,M,N] + outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:selfPlaceholder + secondaryTensor:mat2Placeholder + name:@"bmm_result"]; + + // Cache the compiled graph and tensor references for reuse + CachedGraph cached_graph; + cached_graph.graph = mpsGraph; + cached_graph.input1 = selfPlaceholder; + cached_graph.input2 = mat2Placeholder; + cached_graph.input3 = nil; // No third input for BMM + cached_graph.output = outputTensor; + graph_cache[cache_key] = cached_graph; + + } // End of cache miss/hit block + + // Create feeds dictionary for graph execution + NSMutableDictionary* feeds = [NSMutableDictionary dictionary]; + + // Create MPSGraphTensorData objects for input tensors + // These wrap the MTLBuffers with the shape information + // Initialize to nil for safe cleanup in exception path + MPSGraphTensorData* selfData = nil; + MPSGraphTensorData* mat2Data = nil; + MPSGraphTensorData* outputData = nil; + + selfData = [[MPSGraphTensorData alloc] initWithMTLBuffer:self_buffer + shape:selfShape + dataType:mps_dtype]; + mat2Data = [[MPSGraphTensorData alloc] initWithMTLBuffer:mat2_buffer + shape:mat2Shape + dataType:mps_dtype]; + + feeds[selfPlaceholder] = selfData; + feeds[mat2Placeholder] = mat2Data; + + // Create output tensor data + outputData = [[MPSGraphTensorData alloc] initWithMTLBuffer:out_buffer + shape:outShape + dataType:mps_dtype]; + + // Build results dictionary + NSDictionary* results = @{ + outputTensor: outputData + }; + + // Execute the batched matrix multiplication + @try { + stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT); + } @catch 
(NSException *exception) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: NSException caught during executeMPSGraph: %s - %s", + [[exception name] UTF8String], [[exception reason] UTF8String]); + // Guard releases against nil + if (selfData) [selfData release]; + if (mat2Data) [mat2Data release]; + if (outputData) [outputData release]; + return Error::Internal; + } + + // Release MPSGraphTensorData objects + [selfData release]; + [mat2Data release]; + [outputData release]; + + return Error::Ok; + + } catch (const std::exception& e) { + ET_LOG(Error, "aoti_torch_mps_bmm_out exception: %s", e.what()); + return Error::Internal; + } catch (...) { + ET_LOG(Error, "aoti_torch_mps_bmm_out: unknown exception"); + return Error::Internal; + } + } +} + AOTITorchError aoti_torch_mps_convolution( AOTITensorHandle input, AOTITensorHandle weight, @@ -827,10 +1137,13 @@ AOTITorchError aoti_torch_mps_convolution( } ET_LOG(Debug, "aoti_torch_mps_convolution: mps_dtype=%d, element_size=%zu", mps_dtype, element_size); + // Get weight's input channel dimension from the weight tensor (not from input) + // For grouped convolutions, weight shape is [C_out, C_in/groups, kH, kW] + int64_t weight_C_in = weight_tensor->sizes()[1]; // This handles grouped convs correctly // Define tensor shapes for placeholders (needed for both cache hit and miss) NSArray* inputShape = @[@(N), @(C_in), @(H_in), @(W_in)]; - NSArray* weightShape = @[@(C_out), @(C_in), @(kernel_h), @(kernel_w)]; + NSArray* weightShape = @[@(C_out), @(weight_C_in), @(kernel_h), @(kernel_w)]; // Create cache key for this convolution GraphCacheKey cache_key; diff --git a/backends/apple/metal/runtime/shims/memory.cpp b/backends/apple/metal/runtime/shims/memory.cpp index ebb5b7642e1..eae8e62beef 100644 --- a/backends/apple/metal/runtime/shims/memory.cpp +++ b/backends/apple/metal/runtime/shims/memory.cpp @@ -430,9 +430,6 @@ AOTITorchError aoti_torch__reinterpret_tensor( InvalidArgument, "aoti_torch__reinterpret_tensor failed: ret_new_tensor is null"); - // Check if storage_offset is not 0 - return error if not - ET_CHECK_OK_OR_RETURN_ERROR(validate_storage_offset(storage_offset)); - // Get the device info from the source tensor to perform device_index // validation int32_t device_type = 0; @@ -470,6 +467,10 @@ AOTITorchError aoti_torch__reinterpret_tensor( "Memory address %p is not being tracked by reference counting system", data_ptr); + // Handle storage offset by adjusting the data pointer + void* adjusted_data = static_cast(data_ptr) + + (storage_offset * dtype_to_element_size(dtype)); + // Convert sizes using utility function from utils.h std::vector sizes = convert_sizes_to_vector(ndim, sizes_ptr); @@ -480,7 +481,7 @@ AOTITorchError aoti_torch__reinterpret_tensor( // Create new tensor view that reinterprets the same memory with different // shape/strides This creates a view, not a copy - the data pointer is shared std::shared_ptr tensor = executorch::extension::from_blob( - data_ptr, // Reuse the same memory from source tensor + adjusted_data, // Use adjusted data pointer with storage offset applied sizes, // New sizes with explicit SizesType strides, // New strides with explicit StridesType dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting @@ -496,11 +497,24 @@ AOTITorchError aoti_torch__reinterpret_tensor( *ret_new_tensor = tensor.get(); + if (adjusted_data != data_ptr) { + ET_LOG( + Debug, + "aoti_torch__reinterpret_tensor: Adjusted original_data=%p, storage_offset=%lld, element_size=%zu, adjusted_data=%p", + data_ptr, + 
storage_offset, + dtype_to_element_size(dtype), + adjusted_data); + + metal_buffer_nocopy(adjusted_data, tensor->nbytes(), true); + } + // Increment the reference count for this memory address only if it is owned // by tensor - memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN + memory_to_n_tensor[adjusted_data] = + memory_to_n_tensor[adjusted_data] == NOT_OWN ? NOT_OWN - : memory_to_n_tensor[data_ptr] + 1; + : memory_to_n_tensor[adjusted_data] + 1; ET_LOG(Debug, "aoti_torch__reinterpret_tensor: successfull"); return Error::Ok; @@ -509,10 +523,92 @@ AOTITorchError aoti_torch__reinterpret_tensor( AOTITorchError aoti_torch_new_tensor_handle( Tensor* orig_handle, Tensor** new_handle) { - (void)orig_handle; - (void)new_handle; - throw std::runtime_error("Not implemented"); - return Error::Internal; + ET_LOG(Debug, "aoti_torch_new_tensor_handle: entered"); + + // Validate input parameters + ET_CHECK_OR_RETURN_ERROR( + orig_handle != nullptr, + InvalidArgument, + "aoti_torch_new_tensor_handle failed: orig_handle is null"); + + ET_CHECK_OR_RETURN_ERROR( + new_handle != nullptr, + InvalidArgument, + "aoti_torch_new_tensor_handle failed: new_handle is null"); + + // Get metadata from the original tensor + int64_t* sizes_ptr; + int64_t* strides_ptr; + int32_t dtype; + int32_t device_type; + int32_t device_index; + + ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_sizes(orig_handle, &sizes_ptr)); + ET_CHECK_OK_OR_RETURN_ERROR( + aoti_torch_get_strides(orig_handle, &strides_ptr)); + ET_CHECK_OK_OR_RETURN_ERROR(aoti_torch_get_dtype(orig_handle, &dtype)); + ET_CHECK_OK_OR_RETURN_ERROR( + aoti_torch_get_device_type(orig_handle, &device_type)); + ET_CHECK_OK_OR_RETURN_ERROR( + aoti_torch_get_device_index(orig_handle, &device_index)); + + int64_t ndim = orig_handle->dim(); + + // Validate dtype + ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype)); + + // Ensure device_index is always 0 + ET_CHECK_OR_RETURN_ERROR( + device_index == 0, + InvalidArgument, + "device_index must be 0, got: %d", + device_index); + + // Get the original data pointer from the source tensor + void* data_ptr = orig_handle->mutable_data_ptr(); + ET_CHECK_OR_RETURN_ERROR( + data_ptr != nullptr, + InvalidArgument, + "Source tensor has null data pointer"); + + // Check if the given memory is in the map + auto memory_it = memory_to_n_tensor.find(data_ptr); + ET_CHECK_OR_RETURN_ERROR( + memory_it != memory_to_n_tensor.end(), + InvalidArgument, + "Memory address %p is not being tracked by reference counting system", + data_ptr); + + // Convert sizes and strides to vectors + auto sizes = convert_sizes_to_vector(ndim, sizes_ptr); + auto strides = convert_strides_to_vector(ndim, sizes_ptr, strides_ptr); + + // Create new tensor that shares the same memory as the original + // This is similar to PyTorch's Tensor copy constructor - creates a new + // tensor object that shares the same underlying storage + std::shared_ptr tensor = executorch::extension::from_blob( + data_ptr, // Share the same memory from source tensor + sizes, // Same sizes as original + strides, // Same strides as original + dtype_to_scalar_type(dtype) // Same dtype as original + ); + + ET_CHECK_OR_RETURN_ERROR( + tensor != nullptr, InvalidArgument, "Failed to create new tensor handle"); + + // Store the tensor so it doesn't get destroyed + tensors.insert(tensor); + + *new_handle = tensor.get(); + + // Increment the reference count for this memory address only if it is owned + // by tensor + memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN 
+ ? NOT_OWN + : memory_to_n_tensor[data_ptr] + 1; + + ET_LOG(Debug, "aoti_torch_new_tensor_handle: successfull"); + return Error::Ok; } // Cleanup function for clearing global state diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md index b27bc1f8a91..7a611c90e82 100644 --- a/examples/models/parakeet/README.md +++ b/examples/models/parakeet/README.md @@ -25,26 +25,47 @@ python export_parakeet_tdt.py --audio /path/to/audio.wav | Argument | Description | |----------|-------------| | `--output-dir` | Output directory for exports (default: `./parakeet_tdt_exports`) | -| `--backend` | Backend for acceleration: `portable`, `xnnpack`, `cuda`, `cuda-windows` (default: `portable`) | +| `--backend` | Backend for acceleration: `portable`, `xnnpack`, `metal`, `cuda`, `cuda-windows` (default: `portable`) | +| `--dtype` | Data type: `fp32`, `bf16` (default: `fp32`). Metal backend supports fp32 and bf16 only. | | `--audio` | Path to audio file for transcription test | **Note:** The preprocessor is always lowered with the portable backend regardless of the `--backend` setting. +### Metal Export (macOS) + +```bash +python export_parakeet_tdt.py --backend metal --output-dir ./parakeet_metal +``` + +This generates: +- `parakeet_tdt.pte` - The compiled model +- `aoti_metal_blob.ptd` - Metal kernel blob required at runtime +- `tokenizer.model` - SentencePiece tokenizer + ## C++ Runner ### Building -First, build ExecuTorch with the LLM preset from the executorch root directory: +First, build ExecuTorch with the appropriate preset from the executorch root directory: ```bash +# For CPU/XNNPACK cmake --workflow --preset llm-release + +# For Metal (macOS) +cmake --workflow --preset llm-debug-metal ``` Then build the parakeet runner: ```bash cd examples/models/parakeet + +# CPU/XNNPACK build cmake --workflow --preset parakeet-cpu + +# Metal build +cmake --workflow --preset parakeet-metal ``` Available presets: @@ -57,10 +78,18 @@ Available presets: From the executorch root directory: ```bash +# CPU/XNNPACK ./cmake-out/examples/models/parakeet/parakeet_runner \ --model_path examples/models/parakeet/parakeet_tdt_exports/parakeet_tdt.pte \ --audio_path /path/to/audio.wav \ --tokenizer_path examples/models/parakeet/parakeet_tdt_exports/tokenizer.model + +# Metal (include .ptd data file) +DYLD_LIBRARY_PATH=/usr/lib ./cmake-out/examples/models/parakeet/parakeet_runner \ + --model_path examples/models/parakeet/parakeet_metal/parakeet_tdt.pte \ + --data_path examples/models/parakeet/parakeet_metal/aoti_metal_blob.ptd \ + --audio_path /path/to/audio.wav \ + --tokenizer_path examples/models/parakeet/parakeet_metal/tokenizer.model ``` ### Runner Arguments @@ -70,4 +99,4 @@ From the executorch root directory: | `--model_path` | Path to Parakeet model (.pte) | | `--audio_path` | Path to input audio file (.wav) | | `--tokenizer_path` | Path to tokenizer file (default: `tokenizer.json`) | -| `--data_path` | Path to data file (.ptd) for delegate data (optional, required for CUDA) | +| `--data_path` | Path to data file (.ptd) for delegate data (required for Metal/CUDA) | diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py index 92e32ca30bf..a53f6920e3e 100644 --- a/examples/models/parakeet/export_parakeet_tdt.py +++ b/examples/models/parakeet/export_parakeet_tdt.py @@ -7,6 +7,7 @@ import tempfile import torch + import torchaudio from executorch.exir import ( EdgeCompileConfig, @@ -363,48 +364,102 @@ def export_all(model): return programs, 
metadata -def lower_to_executorch(programs, metadata=None, backend="portable"): +def _create_xnnpack_partitioners(programs): + """Create XNNPACK partitioners for all programs except preprocessor.""" + from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackPartitioner, + ) + + print("\nLowering to ExecuTorch with XNNPACK...") partitioner = {} + for key in programs.keys(): + if key == "preprocessor": + partitioner[key] = [] + else: + partitioner[key] = [XnnpackPartitioner()] + return partitioner, programs + + +def _linear_bias_decomposition(input, weight, bias=None): + """Decompose linear with bias into matmul + add.""" + # linear(input, weight) = input @ weight.T + # Use matmul instead of mm to handle batched inputs (3D+) + weight_t = torch.ops.aten.t.default(weight) + out = torch.ops.aten.matmul.default(input, weight_t) + if bias is not None: + return torch.ops.aten.add.Tensor(out, bias) + return out + + +def _create_metal_partitioners(programs): + """Create Metal partitioners for all programs except preprocessor.""" + from executorch.backends.apple.metal.metal_backend import MetalBackend + from executorch.backends.apple.metal.metal_partitioner import MetalPartitioner + + print("\nLowering to ExecuTorch with Metal...") + + # Run decompositions for non-preprocessor programs + updated_programs = {} + for key, ep in programs.items(): + # print(f"Running decompositions for {key}") + # print(ep.graph_module) + if key != "preprocessor": + updated_programs[key] = ep.run_decompositions( + {torch.ops.aten.linear.default: _linear_bias_decomposition} + ) + else: + updated_programs[key] = ep - if backend == "xnnpack": - from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( - XnnpackPartitioner, - ) + partitioner = {} + for key in updated_programs.keys(): + if key == "preprocessor": + partitioner[key] = [] + else: + compile_specs = [MetalBackend.generate_method_name_compile_spec(key)] + partitioner[key] = [MetalPartitioner(compile_specs)] + return partitioner, updated_programs + + +def _create_cuda_partitioners(programs, is_windows=False): + """Create CUDA partitioners for all programs except preprocessor.""" + from executorch.backends.cuda.cuda_backend import CudaBackend + from executorch.backends.cuda.cuda_partitioner import CudaPartitioner + from executorch.exir.backend.compile_spec_schema import CompileSpec + from torch._inductor.decomposition import conv1d_to_conv2d + + print(f"\nLowering to ExecuTorch with CUDA{' (Windows)' if is_windows else ''}...") + + # Run decompositions for non-preprocessor programs + updated_programs = {} + for key, ep in programs.items(): + if key != "preprocessor": + updated_programs[key] = ep.run_decompositions( + {torch.ops.aten.conv1d.default: conv1d_to_conv2d} + ) + else: + updated_programs[key] = ep - print("\nLowering to ExecuTorch with XNNPACK...") - for key in programs.keys(): - if key == "preprocessor": - partitioner[key] = [] - else: - partitioner[key] = [XnnpackPartitioner()] + partitioner = {} + for key in updated_programs.keys(): + if key == "preprocessor": + partitioner[key] = [] + else: + compile_specs = [CudaBackend.generate_method_name_compile_spec(key)] + if is_windows: + compile_specs.append(CompileSpec("platform", "windows".encode("utf-8"))) + partitioner[key] = [CudaPartitioner(compile_specs)] + return partitioner, updated_programs - elif backend in ("cuda", "cuda-windows"): - from executorch.backends.cuda.cuda_backend import CudaBackend - from executorch.backends.cuda.cuda_partitioner import 
CudaPartitioner - from executorch.exir.backend.compile_spec_schema import CompileSpec - from torch._inductor.decomposition import conv1d_to_conv2d - print( - f"\nLowering to ExecuTorch with CUDA{' (Windows)' if backend == 'cuda-windows' else ''}..." +def lower_to_executorch(programs, metadata=None, backend="portable"): + if backend == "xnnpack": + partitioner, programs = _create_xnnpack_partitioners(programs) + elif backend == "metal": + partitioner, programs = _create_metal_partitioners(programs) + elif backend in ("cuda", "cuda-windows"): + partitioner, programs = _create_cuda_partitioners( + programs, is_windows=(backend == "cuda-windows") ) - - for key, ep in programs.items(): - if key != "preprocessor": - programs[key] = ep.run_decompositions( - {torch.ops.aten.conv1d.default: conv1d_to_conv2d} - ) - - for key in programs.keys(): - if key == "preprocessor": - partitioner[key] = [] - else: - compile_specs = [CudaBackend.generate_method_name_compile_spec(key)] - if backend == "cuda-windows": - compile_specs.append( - CompileSpec("platform", "windows".encode("utf-8")) - ) - partitioner[key] = [CudaPartitioner(compile_specs)] - else: print("\nLowering to ExecuTorch...") partitioner = [] @@ -442,11 +497,22 @@ def main(): "--backend", type=str, default="portable", - choices=["portable", "xnnpack", "cuda", "cuda-windows"], + choices=["portable", "xnnpack", "metal", "cuda", "cuda-windows"], help="Backend for acceleration (default: portable)", ) + parser.add_argument( + "--dtype", + type=str, + default="fp32", + choices=["fp32", "fp16", "bf16"], + help="Model dtype for Metal/CUDA backends (default: fp32)", + ) args = parser.parse_args() + # Validate dtype for Metal backend + if args.backend == "metal" and args.dtype == "fp16": + parser.error("Metal backend only supports fp32 and bf16, not fp16") + os.makedirs(args.output_dir, exist_ok=True) print("Extracting tokenizer...") @@ -455,6 +521,14 @@ def main(): print("Loading model...") model = load_model() + # Convert model to specified dtype for Metal/CUDA backends + if args.dtype == "bf16": + print("Converting model to bfloat16...") + model = model.to(torch.bfloat16) + elif args.dtype == "fp16": + print("Converting model to float16...") + model = model.to(torch.float16) + print("\nExporting components...") programs, metadata = export_all(model) From acc344bf31ad2716a064f46302a51e2848bbac78 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Wed, 14 Jan 2026 09:29:59 -0500 Subject: [PATCH 2/2] Debug patches --- backends/apple/metal/runtime/shims/et_metal.h | 1 + .../apple/metal/runtime/shims/et_metal.mm | 78 +++++- .../apple/metal/runtime/shims/et_metal_ops.mm | 26 +- examples/models/parakeet/CMakePresets.json | 5 +- examples/models/parakeet/main.cpp | 224 ++++++++++++++++++ 5 files changed, 329 insertions(+), 5 deletions(-) diff --git a/backends/apple/metal/runtime/shims/et_metal.h b/backends/apple/metal/runtime/shims/et_metal.h index e4d71fed72e..2f799a44437 100644 --- a/backends/apple/metal/runtime/shims/et_metal.h +++ b/backends/apple/metal/runtime/shims/et_metal.h @@ -380,6 +380,7 @@ int metal_copy_memory( bool dst_is_device); void metal_cleanup_resources(); void metal_buffer_nocopy(void* ptr, size_t nbytes, bool map_ptr_to_buffer); +void metal_log_buffer_stats(); // Helper functions to access Metal objects MTLDevice_t get_metal_device(); diff --git a/backends/apple/metal/runtime/shims/et_metal.mm b/backends/apple/metal/runtime/shims/et_metal.mm index 4f4464a534c..6a5a6bc4706 100644 --- a/backends/apple/metal/runtime/shims/et_metal.mm +++ 
b/backends/apple/metal/runtime/shims/et_metal.mm @@ -17,6 +17,8 @@ #include #include #include +#include +#include namespace executorch { namespace backends { @@ -113,6 +115,25 @@ void metal_cleanup_resources() { } } +// New function to log Metal buffer statistics +extern "C" void metal_log_buffer_stats() { + ET_LOG(Info, "Metal buffer map size: %zu buffers", ptr_to_mtl_buffer.size()); + // Log first few and last few buffer pointers for pattern analysis + if (ptr_to_mtl_buffer.size() > 0 && ptr_to_mtl_buffer.size() <= 10) { + for (const auto& pair : ptr_to_mtl_buffer) { + id buf = pair.second; + ET_LOG(Info, " Buffer: ptr=%p, mtl_length=%zu", pair.first, (size_t)[buf length]); + } + } else if (ptr_to_mtl_buffer.size() > 10) { + // Just log count and total size + size_t total_size = 0; + for (const auto& pair : ptr_to_mtl_buffer) { + total_size += (size_t)[pair.second length]; + } + ET_LOG(Info, " Total buffer size: %zu bytes", total_size); + } +} + void metal_buffer_nocopy(void* ptr, size_t nbytes, bool map_ptr_to_buffer) { id device = get_metal_device(); id subBuffer = [device newBufferWithBytesNoCopy:ptr @@ -345,9 +366,51 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev void* data_ptr = tensor.mutable_data_ptr(); size_t totalSize = tensor.numel() * tensor.element_size(); + int64_t numel = tensor.numel(); auto it = ptr_to_mtl_buffer.find(data_ptr); - if (it != ptr_to_mtl_buffer.end()) { + bool isGpuBuffer = (it != ptr_to_mtl_buffer.end()); + + // DEBUG: Print critical tensor info + const char* bufferType = isGpuBuffer ? "GPU" : "CPU"; + ET_LOG(Info, "setArg[%u]: %s buffer, numel=%lld, totalSize=%zu, ptr=%p", idx, bufferType, numel, totalSize, data_ptr); + + // DEBUG: For float tensors with position 1930, print critical values + if (tensor.scalar_type() == exec_aten::ScalarType::Float && numel > 1930) { + float* float_data = static_cast(data_ptr); + ET_LOG(Info, " setArg[%u] at [10]=%f, [650]=%f, [1290]=%f, [1930]=%f", + idx, float_data[10], float_data[650], float_data[1290], float_data[1930]); + // Check for NaN/Inf at all critical positions + if (std::isnan(float_data[10]) || std::isinf(float_data[10])) { + ET_LOG(Error, " setArg[%u]: NaN/Inf detected at position 10!", idx); + } + if (std::isnan(float_data[1930]) || std::isinf(float_data[1930])) { + ET_LOG(Error, " setArg[%u]: NaN/Inf detected at position 1930!", idx); + } + } + // DEBUG: For c_old tensor (numel=1280), check position 650 (layer 1 cell state) + if (tensor.scalar_type() == exec_aten::ScalarType::Float && numel == 1280) { + float* float_data = static_cast(data_ptr); + ET_LOG(Info, " setArg[%u] c_old at [10]=%f, [650]=%f (layer1 pos10)", + idx, float_data[10], float_data[650]); + if (std::isnan(float_data[650]) || std::isinf(float_data[650])) { + ET_LOG(Error, " setArg[%u]: c_old NaN/Inf at position 650!", idx); + } + } + // DEBUG: For output buffers (numel=640), check position 10 (layer 1 position after concat) + if (tensor.scalar_type() == exec_aten::ScalarType::Float && numel == 640) { + float* float_data = static_cast(data_ptr); + ET_LOG(Info, " setArg[%u] buf640 at [10]=%f, [100]=%f, [639]=%f", + idx, float_data[10], float_data[100], float_data[639]); + if (std::isnan(float_data[10]) || std::isinf(float_data[10])) { + // Print the raw bits of the NaN to understand its source + uint32_t bits; + std::memcpy(&bits, &float_data[10], sizeof(bits)); + ET_LOG(Error, " setArg[%u]: buf640 NaN/Inf at position 10! 
raw_bits=0x%08x", idx, bits); + } + } + + if (isGpuBuffer) { // Use existing Metal buffer id mtlBuffer = it->second; [encoder_ setBuffer:mtlBuffer offset:0 atIndex:idx]; @@ -457,6 +520,16 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev [encoder_ dispatchThreads:size threadsPerThreadgroup:threadGroupSize]; ET_LOG(Debug, "ETMetalKernelFunction::dispatchSingle: Dispatched with length %llu, group size %llu", length, actualGroupSize); + // DEBUG: Force synchronization after each kernel dispatch to debug NaN issue + ETMetalStream* stream = getCurrentMetalStream(); + stream->synchronize(SyncType::COMMIT_AND_WAIT); + + // DEBUG: For length=640 (LSTM output), check if NaN was produced + if (length == 640) { + // Check the last two buffers that were set (output buffers at index 0 and 1) + // We need to read from the buffers after sync + ET_LOG(Info, "dispatchSingle: LSTM kernel completed (length=%llu), checking outputs...", length); + } } void ETMetalKernelFunction::dispatchSingleWithGroupSize(uint64_t length, uint64_t group_size) { @@ -933,6 +1006,9 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev executionDescriptor:nil]; } }); + + // Apply the requested synchronization type + synchronize(syncType); } // ======================= diff --git a/backends/apple/metal/runtime/shims/et_metal_ops.mm b/backends/apple/metal/runtime/shims/et_metal_ops.mm index 5b413728de5..127be0e27ac 100644 --- a/backends/apple/metal/runtime/shims/et_metal_ops.mm +++ b/backends/apple/metal/runtime/shims/et_metal_ops.mm @@ -20,6 +20,7 @@ #include #include #include +#include namespace executorch { namespace backends { @@ -600,7 +601,8 @@ AOTITorchError aoti_torch_mps_mm_out( @try { // Use stream helper to encode and synchronize correctly - stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT); + // Using COMMIT_AND_WAIT to force synchronous execution for debugging + stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_WAIT); } @catch (NSException *exception) { ET_LOG(Error, "aoti_torch_mps_mm_out: NSException caught during executeMPSGraph: %s - %s", [[exception name] UTF8String], [[exception reason] UTF8String]); @@ -609,6 +611,22 @@ AOTITorchError aoti_torch_mps_mm_out( ET_LOG(Debug, "aoti_torch_mps_mm_out: MPSGraph execution completed successfully"); + // DEBUG: Check for NaN in matmul output at critical positions + float* out_data = static_cast(out_tensor->mutable_data_ptr()); + int64_t out_numel = out_tensor->numel(); + bool has_nan = false; + for (int64_t i = 0; i < out_numel; i++) { + if (std::isnan(out_data[i]) || std::isinf(out_data[i])) { + ET_LOG(Error, "aoti_torch_mps_mm_out: NaN/Inf at output[%lld] = %f", i, out_data[i]); + has_nan = true; + if (i > 10) break; // Only log first few + } + } + // Check position 1930 specifically (output gate for position 10) + if (out_numel > 1930) { + ET_LOG(Info, "aoti_torch_mps_mm_out: output[1930] = %f (output gate pos 10)", out_data[1930]); + } + [selfData release]; [mat2Data release]; [outputData release]; @@ -908,7 +926,8 @@ AOTITorchError aoti_torch_mps_bmm_out( // Execute the batched matrix multiplication @try { - stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT); + // Using COMMIT_AND_WAIT for debugging + stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_WAIT); } @catch (NSException *exception) { ET_LOG(Error, "aoti_torch_mps_bmm_out: NSException caught during executeMPSGraph: %s - %s", [[exception name] UTF8String], [[exception 
reason] UTF8String]); @@ -1343,7 +1362,8 @@ AOTITorchError aoti_torch_mps_convolution( @try { // Use stream helper to encode and synchronize correctly - stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT); + // Using COMMIT_AND_WAIT to force synchronous execution for debugging + stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_AND_WAIT); } @catch (NSException *exception) { ET_LOG(Error, "aoti_torch_mps_convolution: NSException caught during executeMPSGraph: %s - %s", [[exception name] UTF8String], [[exception reason] UTF8String]); diff --git a/examples/models/parakeet/CMakePresets.json b/examples/models/parakeet/CMakePresets.json index ea93d257ba7..29e7452aa59 100644 --- a/examples/models/parakeet/CMakePresets.json +++ b/examples/models/parakeet/CMakePresets.json @@ -34,7 +34,10 @@ "displayName": "Parakeet runner (Metal)", "inherits": ["parakeet-base"], "cacheVariables": { - "EXECUTORCH_BUILD_METAL": "ON" + "EXECUTORCH_BUILD_METAL": "ON", + "CMAKE_BUILD_TYPE": "Debug", + "EXECUTORCH_ENABLE_LOGGING": "ON", + "ET_MIN_LOG_LEVEL": "Info" }, "condition": { "lhs": "${hostSystemName}", diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp index 026f3911a3d..d5d73531fca 100644 --- a/examples/models/parakeet/main.cpp +++ b/examples/models/parakeet/main.cpp @@ -6,6 +6,7 @@ * LICENSE file in the root directory of this source tree. */ +#include #include #include #include @@ -23,6 +24,14 @@ #include #include +// For Metal synchronization debugging +#if defined(__APPLE__) +#include +#endif + +// For AOTI tensor metadata cache cleanup +#include + DEFINE_string(model_path, "parakeet.pte", "Path to Parakeet model (.pte)."); DEFINE_string(audio_path, "", "Path to input audio file (.wav)."); DEFINE_string( @@ -56,6 +65,12 @@ std::vector greedy_decode_executorch( std::vector hypothesis; int64_t num_token_classes = vocab_size + 1; + // Debug counters + int64_t total_blank_count = 0; + int64_t total_iterations = 0; + int64_t max_consecutive_non_blank = 0; + int64_t current_consecutive_non_blank = 0; + // Transpose encoder output from [1, enc_dim, time] to [1, time, enc_dim] auto enc_sizes = encoder_output.sizes(); int64_t batch = enc_sizes[0]; @@ -121,6 +136,12 @@ std::vector greedy_decode_executorch( ET_LOG(Error, "decoder_predict (SOS) failed"); return hypothesis; } + + // Force GPU synchronization for SOS initialization +#if defined(__APPLE__) + executorch::backends::metal::synchronize_metal_stream(); +#endif + auto& init_outputs = decoder_init_result.get(); auto g_init = init_outputs[0].toTensor(); auto new_h_init = init_outputs[1].toTensor(); @@ -134,6 +155,9 @@ std::vector greedy_decode_executorch( new_c_init.const_data_ptr(), c_data.size() * sizeof(float)); + // Clear AOTI tensor metadata cache after SOS initialization + executorch::backends::aoti::cleanup_tensor_metadata(); + auto g_proj_result = model.execute( "joint_project_decoder", std::vector<::executorch::runtime::EValue>{g_init}); @@ -161,6 +185,29 @@ std::vector greedy_decode_executorch( for (int64_t d = 0; d < proj_dim; d++) { f_t_data[d] = f_proj_data[t * proj_dim + d]; } + + // Log encoder frame stats around critical time steps + if (t >= 248 && t <= 260) { + float f_sum = 0.0f, f_max = f_t_data[0], f_min = f_t_data[0]; + int nan_count = 0; + for (size_t i = 0; i < f_t_data.size(); i++) { + if (std::isnan(f_t_data[i]) || std::isinf(f_t_data[i])) { + nan_count++; + } + f_sum += f_t_data[i]; + f_max = std::max(f_max, f_t_data[i]); + f_min = std::min(f_min, f_t_data[i]); + } + 
ET_LOG( + Info, + "Encoder frame[t=%lld]: sum=%.4f, min=%.4f, max=%.4f, nan_inf=%d", + static_cast(t), + f_sum, + f_min, + f_max, + nan_count); + } + auto f_t = from_blob( f_t_data.data(), {1, 1, static_cast<::executorch::aten::SizesType>(proj_dim)}, @@ -204,17 +251,91 @@ std::vector greedy_decode_executorch( } int64_t dur = DURATIONS[dur_idx]; + // Update debug counters + total_iterations++; + + // Debug logging for first 20 iterations and when issues appear + bool should_log = hypothesis.size() < 20 || k >= vocab_size || + (hypothesis.size() >= 95 && hypothesis.size() <= 115); + if (should_log) { + ET_LOG( + Info, + "Decode[t=%lld, hyp_len=%zu]: k=%lld (blank=%lld), dur=%lld, " + "token_logit=%.4f, dur_logit=%.4f, symbols_on_frame=%lld", + static_cast(t), + hypothesis.size(), + static_cast(k), + static_cast(blank_id), + static_cast(dur), + max_token_logit, + max_dur_logit, + static_cast(symbols_on_frame)); + + // Also log the top 3 token logits to see distribution + std::vector> top_logits; + for (int64_t i = 0; i < num_token_classes; i++) { + top_logits.push_back({logits_data[i], i}); + } + std::sort(top_logits.begin(), top_logits.end(), std::greater<>()); + ET_LOG( + Info, + " Top3 tokens: [%lld]=%.3f, [%lld]=%.3f, [%lld]=%.3f, blank[%lld]=%.3f", + static_cast(top_logits[0].second), top_logits[0].first, + static_cast(top_logits[1].second), top_logits[1].first, + static_cast(top_logits[2].second), top_logits[2].first, + static_cast(blank_id), logits_data[blank_id]); + } + + // Warn if token is out of range + if (k > vocab_size) { + ET_LOG( + Error, + "Invalid token id %lld (vocab_size=%lld) at t=%lld", + static_cast(k), + static_cast(vocab_size), + static_cast(t)); + } + if (k == blank_id) { t += std::max(dur, (int64_t)1); symbols_on_frame = 0; + total_blank_count++; + // Track max consecutive non-blank before reset + if (current_consecutive_non_blank > max_consecutive_non_blank) { + max_consecutive_non_blank = current_consecutive_non_blank; + } + current_consecutive_non_blank = 0; } else { hypothesis.push_back(k); + current_consecutive_non_blank++; // Update decoder state std::vector token_data = {k}; auto token = from_blob( token_data.data(), {1, 1}, ::executorch::aten::ScalarType::Long); + // Log input state BEFORE decoder_predict for critical iterations + if (hypothesis.size() >= 99 && hypothesis.size() <= 102) { + float h_sum_before = 0.0f, c_sum_before = 0.0f; + int nan_before = 0; + for (size_t i = 0; i < h_data.size(); i++) { + if (std::isnan(h_data[i]) || std::isinf(h_data[i])) nan_before++; + h_sum_before += h_data[i]; + } + for (size_t i = 0; i < c_data.size(); i++) { + if (std::isnan(c_data[i]) || std::isinf(c_data[i])) nan_before++; + c_sum_before += c_data[i]; + } + ET_LOG( + Info, + " BEFORE decoder_predict[hyp=%zu]: h_sum=%.4f, c_sum=%.4f, nan_inf=%d, token=%lld", + hypothesis.size(), + h_sum_before, + c_sum_before, + nan_before, + static_cast(k)); + } + auto decoder_result = model.execute( "decoder_predict", std::vector<::executorch::runtime::EValue>{token, h, c}); @@ -222,6 +343,13 @@ std::vector greedy_decode_executorch( ET_LOG(Error, "decoder_predict failed"); return hypothesis; } + + // Force GPU synchronization to flush any internal MPSGraph state + // This is a debugging workaround for potential MPSGraph caching issues +#if defined(__APPLE__) + executorch::backends::metal::synchronize_metal_stream(); +#endif + auto& outputs = decoder_result.get(); auto g = outputs[0].toTensor(); auto new_h = outputs[1].toTensor(); @@ -237,6 +365,80 @@ std::vector 
greedy_decode_executorch( new_c.const_data_ptr(), c_data.size() * sizeof(float)); + // CRASH ON NAN for lldb debugging + for (size_t i = 0; i < h_data.size(); i++) { + if (std::isnan(h_data[i]) || std::isinf(h_data[i])) { + ET_LOG(Error, "NaN/Inf detected in h_data[%zu] = %f at hyp=%zu", + i, h_data[i], hypothesis.size()); + __builtin_trap(); // Crash here for lldb + } + } + for (size_t i = 0; i < c_data.size(); i++) { + if (std::isnan(c_data[i]) || std::isinf(c_data[i])) { + ET_LOG(Error, "NaN/Inf detected in c_data[%zu] = %f at hyp=%zu", + i, c_data[i], hypothesis.size()); + // __builtin_trap(); // Crash here for lldb + } + } + + // Recreate tensor wrappers to avoid AOTI caching issues with tensor identity + h = from_blob( + h_data.data(), + {static_cast<::executorch::aten::SizesType>(num_rnn_layers), + 1, + static_cast<::executorch::aten::SizesType>(pred_hidden)}, + ::executorch::aten::ScalarType::Float); + c = from_blob( + c_data.data(), + {static_cast<::executorch::aten::SizesType>(num_rnn_layers), + 1, + static_cast<::executorch::aten::SizesType>(pred_hidden)}, + ::executorch::aten::ScalarType::Float); + + // Clear AOTI tensor metadata cache to prevent stale pointer issues + executorch::backends::aoti::cleanup_tensor_metadata(); + + // Check for NaN/Inf in LSTM state (first 20 tokens, around failure point, or when issues occur) + bool should_log_state = hypothesis.size() <= 20 || k >= vocab_size || + (hypothesis.size() >= 95 && hypothesis.size() <= 115); + if (should_log_state) { + // Log Metal buffer statistics to check for buffer accumulation +#if defined(__APPLE__) + executorch::backends::metal::metal_log_buffer_stats(); +#endif + // Log pointer addresses to detect aliasing + ET_LOG( + Info, + " Buffer ptrs[hyp=%zu]: h_data=%p, c_data=%p, new_h=%p, new_c=%p", + hypothesis.size(), + static_cast(h_data.data()), + static_cast(c_data.data()), + static_cast(new_h.const_data_ptr()), + static_cast(new_c.const_data_ptr())); + float h_sum = 0.0f, c_sum = 0.0f; + int nan_count = 0; + for (size_t i = 0; i < h_data.size(); i++) { + if (std::isnan(h_data[i]) || std::isinf(h_data[i])) { + nan_count++; + } + h_sum += h_data[i]; + } + for (size_t i = 0; i < c_data.size(); i++) { + if (std::isnan(c_data[i]) || std::isinf(c_data[i])) { + nan_count++; + } + c_sum += c_data[i]; + } + // Always log since we're in should_log_state block + ET_LOG( + Info, + " LSTM state[hyp=%zu]: h_sum=%.4f, c_sum=%.4f, nan_inf_count=%d", + hypothesis.size(), + h_sum, + c_sum, + nan_count); + } + // Project decoder output auto proj_dec_result = model.execute( "joint_project_decoder", @@ -256,6 +458,12 @@ std::vector greedy_decode_executorch( if (dur == 0) { symbols_on_frame++; if (symbols_on_frame >= max_symbols_per_step) { + // Log when hitting the limit - this might indicate a problem + ET_LOG( + Info, + " Hit max_symbols_per_step at t=%lld, hyp_len=%zu, forcing advance", + static_cast(t), + hypothesis.size()); t++; symbols_on_frame = 0; } @@ -265,6 +473,22 @@ std::vector greedy_decode_executorch( } } + // Final update for max consecutive + if (current_consecutive_non_blank > max_consecutive_non_blank) { + max_consecutive_non_blank = current_consecutive_non_blank; + } + + // Summary statistics + ET_LOG( + Info, + "Decode summary: total_iterations=%lld, tokens=%zu, blanks=%lld, " + "max_consecutive_non_blank=%lld, encoder_len=%lld", + static_cast(total_iterations), + hypothesis.size(), + static_cast(total_blank_count), + static_cast(max_consecutive_non_blank), + static_cast(encoder_len)); + return hypothesis; }
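
A minimal sketch (not part of the patch) illustrating the linear decomposition that `_create_metal_partitioners` registers via `run_decompositions` in `export_parakeet_tdt.py`. It checks that `_linear_bias_decomposition` matches `torch.nn.functional.linear` for a batched (3-D) input, which is the case the matmul-based rewrite targets; the tensor shapes below are illustrative only.

```python
import torch


def _linear_bias_decomposition(input, weight, bias=None):
    # Same decomposition as in the patch: linear(input, weight, bias)
    # becomes input @ weight.T (+ bias), using matmul so 3-D inputs work.
    weight_t = torch.ops.aten.t.default(weight)
    out = torch.ops.aten.matmul.default(input, weight_t)
    if bias is not None:
        return torch.ops.aten.add.Tensor(out, bias)
    return out


# Batched (3-D) input, the case the matmul-based rewrite is meant to handle.
x = torch.randn(2, 5, 16)  # [batch, seq, in_features]
w = torch.randn(32, 16)    # [out_features, in_features]
b = torch.randn(32)

ref = torch.nn.functional.linear(x, w, b)
dec = _linear_bias_decomposition(x, w, b)
assert torch.allclose(ref, dec, atol=1e-5)
```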