Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions backends/aoti/common_shims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,23 +218,24 @@ AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_get_storage_size(Tensor* tensor, int64_t* ret_size) {
(void)tensor;
(void)ret_size;
throw std::runtime_error("Not implemented");
throw std::runtime_error("Not implemented: aoti_torch_get_storage_size");
return Error::Internal;
}

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_clone_preserve_strides(Tensor* self, Tensor** ret_new_tensor) {
(void)self;
(void)ret_new_tensor;
throw std::runtime_error("Not implemented");
throw std::runtime_error(
"Not implemented: aoti_torch_clone_preserve_strides");
return Error::Internal;
}

AOTI_SHIM_EXPORT AOTITorchError
aoti_torch_clone(Tensor* self, Tensor** ret_new_tensor) {
(void)self;
(void)ret_new_tensor;
throw std::runtime_error("Not implemented");
throw std::runtime_error("Not implemented: aoti_torch_clone");
return Error::Internal;
}

Expand All @@ -257,7 +258,8 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
(void)device_type;
(void)device_index;
(void)ret_new_tensor;
throw std::runtime_error("Not implemented");
throw std::runtime_error(
"Not implemented: aoti_torch_create_tensor_from_blob");
return Error::Internal;
}

Expand Down
2 changes: 1 addition & 1 deletion backends/apple/metal/metal_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_device_name(cls) -> str:
@classmethod
def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
return {
"aoti_torch_mps_addmm_out": None,
"aoti_torch_mps_bmm_out": None,
"aoti_torch_mps_convolution": None,
"aoti_torch_mps_mm_out": None,
"at::_ops::_scaled_dot_product_attention_math_for_mps::call": None,
Expand Down
2 changes: 2 additions & 0 deletions backends/apple/metal/runtime/shims/et_metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,8 @@ int metal_copy_memory(
bool src_is_device,
bool dst_is_device);
void metal_cleanup_resources();
void metal_buffer_nocopy(void* ptr, size_t nbytes, bool map_ptr_to_buffer);
void metal_log_buffer_stats();

// Helper functions to access Metal objects
MTLDevice_t get_metal_device();
Expand Down
90 changes: 89 additions & 1 deletion backends/apple/metal/runtime/shims/et_metal.mm
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include <algorithm>
#include <optional>
#include <exception>
#include <cmath>
#include <cstring>

namespace executorch {
namespace backends {
Expand Down Expand Up @@ -113,6 +115,37 @@ void metal_cleanup_resources() {
}
}

// New function to log Metal buffer statistics
extern "C" void metal_log_buffer_stats() {
ET_LOG(Info, "Metal buffer map size: %zu buffers", ptr_to_mtl_buffer.size());
// Log first few and last few buffer pointers for pattern analysis
if (ptr_to_mtl_buffer.size() > 0 && ptr_to_mtl_buffer.size() <= 10) {
for (const auto& pair : ptr_to_mtl_buffer) {
id<MTLBuffer> buf = pair.second;
ET_LOG(Info, " Buffer: ptr=%p, mtl_length=%zu", pair.first, (size_t)[buf length]);
}
} else if (ptr_to_mtl_buffer.size() > 10) {
// Just log count and total size
size_t total_size = 0;
for (const auto& pair : ptr_to_mtl_buffer) {
total_size += (size_t)[pair.second length];
}
ET_LOG(Info, " Total buffer size: %zu bytes", total_size);
}
}

void metal_buffer_nocopy(void* ptr, size_t nbytes, bool map_ptr_to_buffer) {
id<MTLDevice> device = get_metal_device();
id<MTLBuffer> subBuffer = [device newBufferWithBytesNoCopy:ptr
length:nbytes
options:MTLResourceCPUCacheModeWriteCombined | MTLResourceStorageModeShared
deallocator:nil];

if (map_ptr_to_buffer) {
ptr_to_mtl_buffer[ptr] = subBuffer; // Map contents to buffer
}
}

bool metal_is_device_pointer(void* ptr) {
return ptr_to_mtl_buffer.find(ptr) != ptr_to_mtl_buffer.end();
}
Expand Down Expand Up @@ -333,9 +366,51 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev

void* data_ptr = tensor.mutable_data_ptr();
size_t totalSize = tensor.numel() * tensor.element_size();
int64_t numel = tensor.numel();

auto it = ptr_to_mtl_buffer.find(data_ptr);
if (it != ptr_to_mtl_buffer.end()) {
bool isGpuBuffer = (it != ptr_to_mtl_buffer.end());

// DEBUG: Print critical tensor info
const char* bufferType = isGpuBuffer ? "GPU" : "CPU";
ET_LOG(Info, "setArg[%u]: %s buffer, numel=%lld, totalSize=%zu, ptr=%p", idx, bufferType, numel, totalSize, data_ptr);

// DEBUG: For float tensors with position 1930, print critical values
if (tensor.scalar_type() == exec_aten::ScalarType::Float && numel > 1930) {
float* float_data = static_cast<float*>(data_ptr);
ET_LOG(Info, " setArg[%u] at [10]=%f, [650]=%f, [1290]=%f, [1930]=%f",
idx, float_data[10], float_data[650], float_data[1290], float_data[1930]);
// Check for NaN/Inf at all critical positions
if (std::isnan(float_data[10]) || std::isinf(float_data[10])) {
ET_LOG(Error, " setArg[%u]: NaN/Inf detected at position 10!", idx);
}
if (std::isnan(float_data[1930]) || std::isinf(float_data[1930])) {
ET_LOG(Error, " setArg[%u]: NaN/Inf detected at position 1930!", idx);
}
}
// DEBUG: For c_old tensor (numel=1280), check position 650 (layer 1 cell state)
if (tensor.scalar_type() == exec_aten::ScalarType::Float && numel == 1280) {
float* float_data = static_cast<float*>(data_ptr);
ET_LOG(Info, " setArg[%u] c_old at [10]=%f, [650]=%f (layer1 pos10)",
idx, float_data[10], float_data[650]);
if (std::isnan(float_data[650]) || std::isinf(float_data[650])) {
ET_LOG(Error, " setArg[%u]: c_old NaN/Inf at position 650!", idx);
}
}
// DEBUG: For output buffers (numel=640), check position 10 (layer 1 position after concat)
if (tensor.scalar_type() == exec_aten::ScalarType::Float && numel == 640) {
float* float_data = static_cast<float*>(data_ptr);
ET_LOG(Info, " setArg[%u] buf640 at [10]=%f, [100]=%f, [639]=%f",
idx, float_data[10], float_data[100], float_data[639]);
if (std::isnan(float_data[10]) || std::isinf(float_data[10])) {
// Print the raw bits of the NaN to understand its source
uint32_t bits;
std::memcpy(&bits, &float_data[10], sizeof(bits));
ET_LOG(Error, " setArg[%u]: buf640 NaN/Inf at position 10! raw_bits=0x%08x", idx, bits);
}
}

if (isGpuBuffer) {
// Use existing Metal buffer
id<MTLBuffer> mtlBuffer = it->second;
[encoder_ setBuffer:mtlBuffer offset:0 atIndex:idx];
Expand Down Expand Up @@ -445,6 +520,16 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev
[encoder_ dispatchThreads:size threadsPerThreadgroup:threadGroupSize];
ET_LOG(Debug, "ETMetalKernelFunction::dispatchSingle: Dispatched with length %llu, group size %llu", length, actualGroupSize);

// DEBUG: Force synchronization after each kernel dispatch to debug NaN issue
ETMetalStream* stream = getCurrentMetalStream();
stream->synchronize(SyncType::COMMIT_AND_WAIT);

// DEBUG: For length=640 (LSTM output), check if NaN was produced
if (length == 640) {
// Check the last two buffers that were set (output buffers at index 0 and 1)
// We need to read from the buffers after sync
ET_LOG(Info, "dispatchSingle: LSTM kernel completed (length=%llu), checking outputs...", length);
}
}

void ETMetalKernelFunction::dispatchSingleWithGroupSize(uint64_t length, uint64_t group_size) {
Expand Down Expand Up @@ -921,6 +1006,9 @@ int metal_copy_memory(void* dst, const void* src, size_t nbytes, bool src_is_dev
executionDescriptor:nil];
}
});

// Apply the requested synchronization type
synchronize(syncType);
}

// =======================
Expand Down
10 changes: 10 additions & 0 deletions backends/apple/metal/runtime/shims/et_metal_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@ AOTITorchError aoti_torch_mps_mm_out(
AOTITensorHandle self,
AOTITensorHandle mat2);

/**
* ExecutorTorch implementation of aoti_torch_mps_bmm_out.
* Performs batched matrix multiplication: out = self @ mat2
* All tensors must be 3-D with matching batch dimensions.
*/
AOTITorchError aoti_torch_mps_bmm_out(
AOTITensorHandle out,
AOTITensorHandle self,
AOTITensorHandle mat2);

/**
* ExecutorTorch implementation of aoti_torch_mps_convolution.
* Performs 2D convolution operation - matches PyTorch AOTI signature
Expand Down
Loading
Loading