diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000..7d74a5bc81 --- /dev/null +++ b/.clang-format @@ -0,0 +1,11 @@ +# clang-format configuration for the react-native-executorch C/C++/Objective-C +# sources. Based on the LLVM style (the project's prior implicit default) with a +# wider column limit so clang-format stops aggressively wrapping function +# signatures, chained ternaries and string concatenations. See issue #1217. +# +# Vendored libraries (common/ada, common/pfft, common/runner) and the generated +# ErrorCodes.h are excluded from formatting via .clang-format-ignore. +BasedOnStyle: LLVM +Standard: c++20 +ColumnLimit: 100 +InsertNewlineAtEOF: true diff --git a/.clang-format-ignore b/.clang-format-ignore new file mode 100644 index 0000000000..4cc2b29426 --- /dev/null +++ b/.clang-format-ignore @@ -0,0 +1,12 @@ +# Files clang-format must skip. Patterns are gitignore-style, relative to this +# file. This keeps `clang-format -i` (and the pre-commit hook) from reformatting +# vendored third-party code and generated sources. + +# Vendored libraries that live outside third-party/. +packages/react-native-executorch/common/ada/** +packages/react-native-executorch/common/pfft/** +packages/react-native-executorch/common/runner/** + +# Generated verbatim by `yarn codegen:errors` and verified in CI with a raw +# diff; reformatting would rewrap its comments and break that check. +packages/react-native-executorch/common/rnexecutorch/ErrorCodes.h diff --git a/packages/react-native-executorch-webrtc/android/src/main/cpp/FrameProcessorBridge.cpp b/packages/react-native-executorch-webrtc/android/src/main/cpp/FrameProcessorBridge.cpp index 1b127564a9..abe62df559 100644 --- a/packages/react-native-executorch-webrtc/android/src/main/cpp/FrameProcessorBridge.cpp +++ b/packages/react-native-executorch-webrtc/android/src/main/cpp/FrameProcessorBridge.cpp @@ -63,8 +63,7 @@ cv::Mat unrotateMat(const cv::Mat &src, int32_t rotation) { extern "C" { -JNIEXPORT jlong JNICALL -Java_com_executorch_webrtc_ExecutorchFrameProcessor_loadModel( +JNIEXPORT jlong JNICALL Java_com_executorch_webrtc_ExecutorchFrameProcessor_loadModel( JNIEnv *env, jobject thiz, jstring modelPath) { const char *pathChars = env->GetStringUTFChars(modelPath, nullptr); if (pathChars == nullptr) { @@ -82,8 +81,8 @@ Java_com_executorch_webrtc_ExecutorchFrameProcessor_loadModel( std::vector allClasses = {"foreground", "background"}; auto handle = std::make_unique(); - handle->segmentation = std::make_unique( - path, normMean, normStd, allClasses, nullptr); + handle->segmentation = + std::make_unique(path, normMean, normStd, allClasses, nullptr); auto inputShapes = handle->segmentation->getAllInputShapes(); if (!inputShapes.empty() && inputShapes[0].size() >= 4) { @@ -99,10 +98,9 @@ Java_com_executorch_webrtc_ExecutorchFrameProcessor_loadModel( } } -JNIEXPORT jbyteArray JNICALL -Java_com_executorch_webrtc_ExecutorchFrameProcessor_runSegmentation( - JNIEnv *env, jobject thiz, jlong handlePtr, jbyteArray rgbaData, jint width, - jint height, jint rotation) { +JNIEXPORT jbyteArray JNICALL Java_com_executorch_webrtc_ExecutorchFrameProcessor_runSegmentation( + JNIEnv *env, jobject thiz, jlong handlePtr, jbyteArray rgbaData, jint width, jint height, + jint rotation) { if (handlePtr == 0) { return nullptr; } @@ -127,15 +125,13 @@ Java_com_executorch_webrtc_ExecutorchFrameProcessor_runSegmentation( pixelData.scalarType = executorch::aten::ScalarType::Byte; std::set> classesOfInterest = {"foreground"}; - auto result = handle->segmentation->generateFromPixels( - pixelData, classesOfInterest, false); + auto result = handle->segmentation->generateFromPixels(pixelData, classesOfInterest, false); cv::Mat mask; if (result.classBuffers && result.classBuffers->count("foreground")) { auto &fgBuffer = result.classBuffers->at("foreground"); auto *fgData = reinterpret_cast(fgBuffer->data()); - mask = cv::Mat(handle->modelHeight, handle->modelWidth, CV_32FC1, fgData) - .clone(); + mask = cv::Mat(handle->modelHeight, handle->modelWidth, CV_32FC1, fgData).clone(); } else { LOGE("No foreground mask in result"); env->ReleaseByteArrayElements(rgbaData, rgbaPtr, JNI_ABORT); @@ -151,16 +147,15 @@ Java_com_executorch_webrtc_ExecutorchFrameProcessor_runSegmentation( handle->previousMask = mask.clone(); handle->hasHistory = true; } else { - cv::addWeighted(handle->previousMask, EMA_ALPHA, mask, 1.0f - EMA_ALPHA, - 0, handle->previousMask); + cv::addWeighted(handle->previousMask, EMA_ALPHA, mask, 1.0f - EMA_ALPHA, 0, + handle->previousMask); mask = handle->previousMask.clone(); } cv::Mat maskRotated = unrotateMat(mask, rotation); cv::Mat maskResized; - cv::resize(maskRotated, maskResized, cv::Size(width, height), 0, 0, - cv::INTER_LINEAR); + cv::resize(maskRotated, maskResized, cv::Size(width, height), 0, 0, cv::INTER_LINEAR); cv::Mat maskBytes; maskResized.convertTo(maskBytes, CV_8UC1, 255.0); @@ -171,8 +166,7 @@ Java_com_executorch_webrtc_ExecutorchFrameProcessor_runSegmentation( env->ReleaseByteArrayElements(rgbaData, rgbaPtr, JNI_ABORT); return nullptr; } - env->SetByteArrayRegion(output, 0, maskSize, - reinterpret_cast(maskBytes.data)); + env->SetByteArrayRegion(output, 0, maskSize, reinterpret_cast(maskBytes.data)); env->ReleaseByteArrayElements(rgbaData, rgbaPtr, JNI_ABORT); return output; @@ -183,8 +177,7 @@ Java_com_executorch_webrtc_ExecutorchFrameProcessor_runSegmentation( } } -JNIEXPORT void JNICALL -Java_com_executorch_webrtc_ExecutorchFrameProcessor_unloadModel( +JNIEXPORT void JNICALL Java_com_executorch_webrtc_ExecutorchFrameProcessor_unloadModel( JNIEnv *env, jobject thiz, jlong handlePtr) { if (handlePtr == 0) { return; diff --git a/packages/react-native-executorch-webrtc/ios/ExecutorchFrameProcessor.mm b/packages/react-native-executorch-webrtc/ios/ExecutorchFrameProcessor.mm index 2fd1d2d1af..1f06447bd0 100644 --- a/packages/react-native-executorch-webrtc/ios/ExecutorchFrameProcessor.mm +++ b/packages/react-native-executorch-webrtc/ios/ExecutorchFrameProcessor.mm @@ -72,9 +72,8 @@ - (void)configureWithModelPath:(NSString *)modelPath { std::vector normMean = {}; std::vector normStd = {}; std::vector allClasses = {"foreground", "background"}; - seg = std::make_unique( - std::string([modelPath UTF8String]), normMean, normStd, allClasses, - nullptr); + seg = std::make_unique(std::string([modelPath UTF8String]), normMean, + normStd, allClasses, nullptr); auto inputShapes = seg->getAllInputShapes(); if (!inputShapes.empty() && inputShapes[0].size() >= 4) { h = inputShapes[0][inputShapes[0].size() - 2]; @@ -181,8 +180,8 @@ - (RTCVideoFrame *)processFrameLocked:(RTCVideoFrame *)frame { } CVPixelBufferRef outputBuffer = NULL; - CVReturn poolStatus = CVPixelBufferPoolCreatePixelBuffer( - kCFAllocatorDefault, _outputPool, &outputBuffer); + CVReturn poolStatus = + CVPixelBufferPoolCreatePixelBuffer(kCFAllocatorDefault, _outputPool, &outputBuffer); if (poolStatus != kCVReturnSuccess) { if (ownsInput) { CVPixelBufferRelease(inputPixelBuffer); @@ -193,15 +192,13 @@ - (RTCVideoFrame *)processFrameLocked:(RTCVideoFrame *)frame { CIImage *original = [CIImage imageWithCVPixelBuffer:inputPixelBuffer]; CGFloat scaleX = (CGFloat)width / maskImage.extent.size.width; CGFloat scaleY = (CGFloat)height / maskImage.extent.size.height; - CIImage *scaledMask = [maskImage - imageByApplyingTransform:CGAffineTransformMakeScale(scaleX, scaleY)]; + CIImage *scaledMask = + [maskImage imageByApplyingTransform:CGAffineTransformMakeScale(scaleX, scaleY)]; CIFilter *blurFilter = [CIFilter filterWithName:@"CIGaussianBlur"]; - [blurFilter setValue:[original imageByClampingToExtent] - forKey:kCIInputImageKey]; + [blurFilter setValue:[original imageByClampingToExtent] forKey:kCIInputImageKey]; [blurFilter setValue:@(_blurRadius) forKey:kCIInputRadiusKey]; - CIImage *blurred = - [blurFilter.outputImage imageByCroppingToRect:original.extent]; + CIImage *blurred = [blurFilter.outputImage imageByCroppingToRect:original.extent]; CIFilter *blendFilter = [CIFilter filterWithName:@"CIBlendWithMask"]; [blendFilter setValue:original forKey:kCIInputImageKey]; @@ -216,12 +213,10 @@ - (RTCVideoFrame *)processFrameLocked:(RTCVideoFrame *)frame { colorSpace:colorSpace]; CGColorSpaceRelease(colorSpace); - RTCCVPixelBuffer *rtcBuffer = - [[RTCCVPixelBuffer alloc] initWithPixelBuffer:outputBuffer]; - RTCVideoFrame *outputFrame = - [[RTCVideoFrame alloc] initWithBuffer:rtcBuffer - rotation:frame.rotation - timeStampNs:frame.timeStampNs]; + RTCCVPixelBuffer *rtcBuffer = [[RTCCVPixelBuffer alloc] initWithPixelBuffer:outputBuffer]; + RTCVideoFrame *outputFrame = [[RTCVideoFrame alloc] initWithBuffer:rtcBuffer + rotation:frame.rotation + timeStampNs:frame.timeStampNs]; CVPixelBufferRelease(outputBuffer); if (ownsInput) { @@ -271,10 +266,9 @@ - (CIImage *)generateMaskLockedForPixelBuffer:(CVPixelBufferRef)pixelBuffer yMat.copyTo(yuvMat(cv::Rect(0, 0, (int)width, (int)height))); std::vector uvChannels; cv::split(uvMat, uvChannels); - uvChannels[0].copyTo( - yuvMat(cv::Rect(0, (int)height, (int)width / 2, (int)height / 2))); - uvChannels[1].copyTo(yuvMat(cv::Rect((int)width / 2, (int)height, - (int)width / 2, (int)height / 2))); + uvChannels[0].copyTo(yuvMat(cv::Rect(0, (int)height, (int)width / 2, (int)height / 2))); + uvChannels[1].copyTo( + yuvMat(cv::Rect((int)width / 2, (int)height, (int)width / 2, (int)height / 2))); cv::cvtColor(yuvMat, rgbMat, cv::COLOR_YUV2RGB_I420); } else { CVPixelBufferUnlockBaseAddress(pixelBuffer, kCVPixelBufferLock_ReadOnly); @@ -302,8 +296,7 @@ - (CIImage *)generateMaskLockedForPixelBuffer:(CVPixelBufferRef)pixelBuffer pixelData.scalarType = executorch::aten::ScalarType::Byte; std::set> classesOfInterest = {"foreground"}; - auto result = - _segmentation->generateFromPixels(pixelData, classesOfInterest, false); + auto result = _segmentation->generateFromPixels(pixelData, classesOfInterest, false); if (result.classBuffers && result.classBuffers->count("foreground")) { auto &fgBuffer = result.classBuffers->at("foreground"); @@ -331,8 +324,7 @@ - (CIImage *)generateMaskLockedForPixelBuffer:(CVPixelBufferRef)pixelBuffer if (_previousMask.empty() || _previousMask.size() != maskRotated.size()) { _previousMask = maskRotated.clone(); } else { - cv::addWeighted(maskRotated, EMA_ALPHA, _previousMask, 1.0f - EMA_ALPHA, 0, - maskRotated); + cv::addWeighted(maskRotated, EMA_ALPHA, _previousMask, 1.0f - EMA_ALPHA, 0, maskRotated); _previousMask = maskRotated.clone(); } @@ -380,9 +372,9 @@ - (CVPixelBufferRef)createPixelBufferFromI420:(id)buffer { (id)kCVPixelBufferIOSurfacePropertiesKey : @{}, }; CVPixelBufferRef pixelBuffer = NULL; - CVReturn result = CVPixelBufferCreate( - kCFAllocatorDefault, width, height, kCVPixelFormatType_32BGRA, - (__bridge CFDictionaryRef)attrs, &pixelBuffer); + CVReturn result = + CVPixelBufferCreate(kCFAllocatorDefault, width, height, kCVPixelFormatType_32BGRA, + (__bridge CFDictionaryRef)attrs, &pixelBuffer); if (result != kCVReturnSuccess) { return NULL; } @@ -398,8 +390,7 @@ - (CVPixelBufferRef)createPixelBufferFromI420:(id)buffer { memcpy(uvDst + row * uvWidth, dataU + row * strideU, uvWidth); } for (int row = 0; row < uvHeight; row++) { - memcpy(uvDst + uvHeight * uvWidth + row * uvWidth, dataV + row * strideV, - uvWidth); + memcpy(uvDst + uvHeight * uvWidth + row * uvWidth, dataV + row * strideV, uvWidth); } cv::Mat bgraMat; @@ -425,9 +416,9 @@ - (CVPixelBufferRef)createGrayscalePixelBuffer:(cv::Mat &)grayMat { (id)kCVPixelBufferIOSurfacePropertiesKey : @{}, }; CVPixelBufferRef pixelBuffer = NULL; - CVReturn result = CVPixelBufferCreate( - kCFAllocatorDefault, width, height, kCVPixelFormatType_OneComponent8, - (__bridge CFDictionaryRef)attrs, &pixelBuffer); + CVReturn result = + CVPixelBufferCreate(kCFAllocatorDefault, width, height, kCVPixelFormatType_OneComponent8, + (__bridge CFDictionaryRef)attrs, &pixelBuffer); if (result != kCVReturnSuccess) { return NULL; } @@ -457,8 +448,7 @@ - (void)ensurePoolLockedForWidth:(size_t)width height:(size_t)height { (id)kCVPixelBufferIOSurfacePropertiesKey : @{}, (id)kCVPixelBufferMetalCompatibilityKey : @YES, }; - CVPixelBufferPoolCreate(kCFAllocatorDefault, NULL, - (__bridge CFDictionaryRef)attrs, &_outputPool); + CVPixelBufferPoolCreate(kCFAllocatorDefault, NULL, (__bridge CFDictionaryRef)attrs, &_outputPool); _poolWidth = width; _poolHeight = height; } diff --git a/packages/react-native-executorch-webrtc/ios/ExecutorchWebRTC.mm b/packages/react-native-executorch-webrtc/ios/ExecutorchWebRTC.mm index e4e8f5cea3..6d16b6a762 100644 --- a/packages/react-native-executorch-webrtc/ios/ExecutorchWebRTC.mm +++ b/packages/react-native-executorch-webrtc/ios/ExecutorchWebRTC.mm @@ -36,15 +36,11 @@ + (void)ensureRegistered { [_processor unloadModel]; } -RCT_EXPORT_METHOD(setBlurRadius : (double)radius) { - [_processor setBlurRadius:(float)radius]; -} +RCT_EXPORT_METHOD(setBlurRadius : (double)radius) { [_processor setBlurRadius:(float)radius]; } RCT_EXPORT_BLOCKING_SYNCHRONOUS_METHOD(isAvailable) { return @YES; } -RCT_EXPORT_BLOCKING_SYNCHRONOUS_METHOD(getProcessorName) { - return PROCESSOR_NAME; -} +RCT_EXPORT_BLOCKING_SYNCHRONOUS_METHOD(getProcessorName) { return PROCESSOR_NAME; } #pragma mark - Legacy API (for backward compatibility) @@ -54,8 +50,7 @@ + (void)ensureRegistered { [self initialize:modelPath]; } -RCT_EXPORT_METHOD(configureBackgroundBlur : (NSString *) - modelPath blurIntensity : (int)intensity) { +RCT_EXPORT_METHOD(configureBackgroundBlur : (NSString *)modelPath blurIntensity : (int)intensity) { [self initialize:modelPath]; [self setBlurRadius:intensity]; } diff --git a/packages/react-native-executorch/android/src/main/cpp/ETInstallerModule.cpp b/packages/react-native-executorch/android/src/main/cpp/ETInstallerModule.cpp index 1a637b15d1..f846a9ba49 100644 --- a/packages/react-native-executorch/android/src/main/cpp/ETInstallerModule.cpp +++ b/packages/react-native-executorch/android/src/main/cpp/ETInstallerModule.cpp @@ -12,16 +12,13 @@ JavaVM *java_machine; using namespace facebook::jni; ETInstallerModule::ETInstallerModule( - jni::alias_ref &jThis, - jsi::Runtime *jsiRuntime, + jni::alias_ref &jThis, jsi::Runtime *jsiRuntime, const std::shared_ptr &jsCallInvoker) - : javaPart_(make_global(jThis)), jsiRuntime_(jsiRuntime), - jsCallInvoker_(jsCallInvoker) {} + : javaPart_(make_global(jThis)), jsiRuntime_(jsiRuntime), jsCallInvoker_(jsCallInvoker) {} jni::local_ref ETInstallerModule::initHybrid( jni::alias_ref jThis, jlong jsContext, - jni::alias_ref - jsCallInvokerHolder) { + jni::alias_ref jsCallInvokerHolder) { auto jsCallInvoker = jsCallInvokerHolder->cthis()->getCallInvoker(); auto rnRuntime = reinterpret_cast(jsContext); return makeCxxInstance(jThis, rnRuntime, jsCallInvoker); @@ -30,8 +27,7 @@ jni::local_ref ETInstallerModule::initHybrid( void ETInstallerModule::registerNatives() { registerHybrid({ makeNativeMethod("initHybrid", ETInstallerModule::initHybrid), - makeNativeMethod("injectJSIBindings", - ETInstallerModule::injectJSIBindings), + makeNativeMethod("injectJSIBindings", ETInstallerModule::injectJSIBindings), }); } @@ -48,12 +44,11 @@ void ETInstallerModule::injectJSIBindings() { } } static jclass cls = javaClassStatic().get(); - static jmethodID method = env->GetStaticMethodID( - cls, "fetchByteDataFromUrl", "(Ljava/lang/String;)[B"); + static jmethodID method = + env->GetStaticMethodID(cls, "fetchByteDataFromUrl", "(Ljava/lang/String;)[B"); jstring jUrl = env->NewStringUTF(url.c_str()); - jbyteArray byteData = - (jbyteArray)env->CallStaticObjectMethod(cls, method, jUrl); + jbyteArray byteData = (jbyteArray)env->CallStaticObjectMethod(cls, method, jUrl); if (env->IsSameObject(byteData, NULL)) { throw std::runtime_error("Error fetching data from a url"); @@ -68,13 +63,12 @@ void ETInstallerModule::injectJSIBindings() { auto _isEmulator = isEmulator(); - RnExecutorchInstaller::injectJSIBindings(jsiRuntime_, jsCallInvoker_, - fetchDataByUrl, _isEmulator); + RnExecutorchInstaller::injectJSIBindings(jsiRuntime_, jsCallInvoker_, fetchDataByUrl, + _isEmulator); } } // namespace rnexecutorch JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *) { rnexecutorch::java_machine = vm; - return facebook::jni::initialize( - vm, [] { rnexecutorch::ETInstallerModule::registerNatives(); }); -} \ No newline at end of file + return facebook::jni::initialize(vm, [] { rnexecutorch::ETInstallerModule::registerNatives(); }); +} diff --git a/packages/react-native-executorch/android/src/main/cpp/ETInstallerModule.h b/packages/react-native-executorch/android/src/main/cpp/ETInstallerModule.h index 82531ac9d5..b61a571741 100644 --- a/packages/react-native-executorch/android/src/main/cpp/ETInstallerModule.h +++ b/packages/react-native-executorch/android/src/main/cpp/ETInstallerModule.h @@ -14,13 +14,11 @@ using namespace react; class ETInstallerModule : public jni::HybridClass { public: - static auto constexpr kJavaDescriptor = - "Lcom/swmansion/rnexecutorch/ETInstaller;"; + static auto constexpr kJavaDescriptor = "Lcom/swmansion/rnexecutorch/ETInstaller;"; static jni::local_ref initHybrid(jni::alias_ref jThis, jlong jsContext, - jni::alias_ref - jsCallInvokerHolder); + jni::alias_ref jsCallInvokerHolder); static void registerNatives(); @@ -33,10 +31,9 @@ class ETInstallerModule : public jni::HybridClass { jsi::Runtime *jsiRuntime_; std::shared_ptr jsCallInvoker_; - explicit ETInstallerModule( - jni::alias_ref &jThis, - jsi::Runtime *jsiRuntime, - const std::shared_ptr &jsCallInvoker); + explicit ETInstallerModule(jni::alias_ref &jThis, + jsi::Runtime *jsiRuntime, + const std::shared_ptr &jsCallInvoker); }; } // namespace rnexecutorch diff --git a/packages/react-native-executorch/android/src/main/cpp/EmulatorDetection.h b/packages/react-native-executorch/android/src/main/cpp/EmulatorDetection.h index 63cc17e398..0f1995f413 100644 --- a/packages/react-native-executorch/android/src/main/cpp/EmulatorDetection.h +++ b/packages/react-native-executorch/android/src/main/cpp/EmulatorDetection.h @@ -15,8 +15,7 @@ inline bool isEmulator() { std::string result; __system_property_read_callback( pi, - [](void *cookie, const char * /*__name*/, const char *value, - uint32_t /*__serial*/) { + [](void *cookie, const char * /*__name*/, const char *value, uint32_t /*__serial*/) { *static_cast(cookie) = value; }, &result); diff --git a/packages/react-native-executorch/common/rnexecutorch/Error.h b/packages/react-native-executorch/common/rnexecutorch/Error.h index edd69d4a07..d2ca4a45cc 100644 --- a/packages/react-native-executorch/common/rnexecutorch/Error.h +++ b/packages/react-native-executorch/common/rnexecutorch/Error.h @@ -11,8 +11,7 @@ namespace rnexecutorch { -using ErrorVariant = - std::variant; +using ErrorVariant = std::variant; class RnExecutorchError : public std::runtime_error { public: @@ -22,9 +21,7 @@ class RnExecutorchError : public std::runtime_error { : std::runtime_error(message), errorCode(code) {} int32_t getNumericCode() const noexcept { - return std::visit( - [](auto &&arg) -> int32_t { return static_cast(arg); }, - errorCode); + return std::visit([](auto &&arg) -> int32_t { return static_cast(arg); }, errorCode); } bool isRnExecutorchError() const noexcept { @@ -50,26 +47,24 @@ inline std::string locationPrefix(const std::source_location &loc) { [[noreturn]] inline void throwNotLoaded(std::source_location loc = std::source_location::current()) { throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, - locationPrefix(loc) + "Model not loaded (in: " + - loc.function_name() + ")"); + locationPrefix(loc) + "Model not loaded (in: " + loc.function_name() + + ")"); } template -inline void checkOkOrThrowForwardError( - const Result &result, - std::source_location loc = std::source_location::current()) { +inline void checkOkOrThrowForwardError(const Result &result, + std::source_location loc = std::source_location::current()) { if (!result.ok()) { - throw RnExecutorchError( - result.error(), locationPrefix(loc) + - "Forward pass failed (in: " + loc.function_name() + - "). Ensure the model input is correct."); + throw RnExecutorchError(result.error(), locationPrefix(loc) + + "Forward pass failed (in: " + loc.function_name() + + "). Ensure the model input is correct."); } } } // namespace detail } // namespace rnexecutorch -#define CHECK_OK_OR_THROW_FORWARD_ERROR(result) \ +#define CHECK_OK_OR_THROW_FORWARD_ERROR(result) \ ::rnexecutorch::detail::checkOkOrThrowForwardError(result) #define THROW_NOT_LOADED_ERROR() ::rnexecutorch::detail::throwNotLoaded() diff --git a/packages/react-native-executorch/common/rnexecutorch/Log.h b/packages/react-native-executorch/common/rnexecutorch/Log.h index bb17a53ec9..2161f8d332 100644 --- a/packages/react-native-executorch/common/rnexecutorch/Log.h +++ b/packages/react-native-executorch/common/rnexecutorch/Log.h @@ -68,8 +68,7 @@ concept ReadOnlySequencableTop = requires(const T &t) { } && HasEmpty && !Iterable; template -concept ReadOnlySequencable = - ReadOnlySequencableFront || ReadOnlySequencableTop; +concept ReadOnlySequencable = ReadOnlySequencableFront || ReadOnlySequencableTop; template concept MutableSequencable = ReadOnlySequencable && HasPop; @@ -89,15 +88,13 @@ template concept WeakPointer = requires(const T &a) { { a.lock() - } -> std::convertible_to< - std::shared_ptr>; // Verifies if a.lock() can - // convert to std::shared_ptr + } -> std::convertible_to>; // Verifies if a.lock() can + // convert to std::shared_ptr }; template -concept Fallback = - !Iterable && !Streamable && !SmartPointer && !WeakPointer && - !ReadOnlySequencable && !MutableSequencable; +concept Fallback = !Iterable && !Streamable && !SmartPointer && !WeakPointer && + !ReadOnlySequencable && !MutableSequencable; } // namespace concepts @@ -105,14 +102,11 @@ template requires concepts::Streamable && (!concepts::SmartPointer) void printElement(std::ostream &os, const T &value); -template -void printElement(std::ostream &os, const std::pair &p); +template void printElement(std::ostream &os, const std::pair &p); -template -void printElement(std::ostream &os, const char (&array)[N]); +template void printElement(std::ostream &os, const char (&array)[N]); -template -void printElement(std::ostream &os, T (&array)[N]); +template void printElement(std::ostream &os, T (&array)[N]); template requires concepts::Iterable && (!concepts::Streamable) @@ -128,27 +122,21 @@ template requires concepts::MutableSequencable void printElement(std::ostream &os, T &&container); -template -void printElement(std::ostream &os, const std::tuple &tpl); +template void printElement(std::ostream &os, const std::tuple &tpl); -template -void printElement(std::ostream &os, const SP &ptr); +template void printElement(std::ostream &os, const SP &ptr); -template -void printElement(std::ostream &os, const WP &ptr); +template void printElement(std::ostream &os, const WP &ptr); -template -void printElement(std::ostream &os, const std::optional &opt); +template void printElement(std::ostream &os, const std::optional &opt); -template -void printElement(std::ostream &os, const std::variant &var); +template void printElement(std::ostream &os, const std::variant &var); void printElement(std::ostream &os, const std::exception_ptr &exPtr); void printElement(std::ostream &os, const std::filesystem::path &path); -void printElement(std::ostream &os, - const std::filesystem::directory_iterator &dir_it); +void printElement(std::ostream &os, const std::filesystem::directory_iterator &dir_it); template void printElement(std::ostream &os, const UnsupportedArg &value); @@ -161,8 +149,7 @@ void printElement(std::ostream &os, const T &value) { os << value; } -template -void printElement(std::ostream &os, const std::pair &p) { +template void printElement(std::ostream &os, const std::pair &p) { os << "("; printElement(os, p.first); os << ", "; @@ -170,8 +157,7 @@ void printElement(std::ostream &os, const std::pair &p) { os << ")"; } -template -void printElement(std::ostream &os, const char (&array)[N]) { +template void printElement(std::ostream &os, const char (&array)[N]) { // Treats the input as a string up to length N, drop null termination if (N > 1) { os << std::string_view(array, N - 1); @@ -179,8 +165,7 @@ void printElement(std::ostream &os, const char (&array)[N]) { } // A special function for C-style arrays deducing size via template -template -void printElement(std::ostream &os, T (&array)[N]) { +template void printElement(std::ostream &os, T (&array)[N]) { os << "["; for (std::size_t i = 0; i < N; ++i) { if (i > 0) { @@ -214,8 +199,7 @@ template void printSequencable(std::ostream &os, T &&container) { if (!isFirst) { os << ", "; } - low_level_log_implementation::printElement( - os, std::forward(element)); + low_level_log_implementation::printElement(os, std::forward(element)); isFirst = false; }; @@ -234,10 +218,9 @@ template void printSequencable(std::ostream &os, T &&container) { template requires concepts::ReadOnlySequencable void printElement(std::ostream &os, const T &container) { - T tempContainer = container; // Make a copy to preserve original container - printSequencable( - os, std::move(tempContainer)); // Use std::move since tempContainer won't - // be used again + T tempContainer = container; // Make a copy to preserve original container + printSequencable(os, std::move(tempContainer)); // Use std::move since tempContainer won't + // be used again } template @@ -246,8 +229,7 @@ void printElement(std::ostream &os, T &&container) { printSequencable(os, std::forward(container)); } -template -void printElement(std::ostream &os, const std::tuple &tpl) { +template void printElement(std::ostream &os, const std::tuple &tpl) { os << "<"; std::apply( [&os](const auto &...args) { @@ -268,8 +250,7 @@ void printElement(std::ostream &os, const std::tuple &tpl) { os << ">"; } -template -void printElement(std::ostream &os, const SP &ptr) { +template void printElement(std::ostream &os, const SP &ptr) { if (ptr) { printElement(os, *ptr); } else { @@ -277,8 +258,7 @@ void printElement(std::ostream &os, const SP &ptr) { } } -template -void printElement(std::ostream &os, const WP &ptr) { +template void printElement(std::ostream &os, const WP &ptr) { auto sp = ptr.lock(); if (sp) { printElement(os, *sp); @@ -287,8 +267,7 @@ void printElement(std::ostream &os, const WP &ptr) { } } -template -void printElement(std::ostream &os, const std::optional &opt) { +template void printElement(std::ostream &os, const std::optional &opt) { if (opt) { os << "Optional("; printElement(os, *opt); @@ -298,8 +277,7 @@ void printElement(std::ostream &os, const std::optional &opt) { } } -template -void printElement(std::ostream &os, const std::variant &var) { +template void printElement(std::ostream &os, const std::variant &var) { std::visit( [&os](const auto &value) { os << "Variant("; @@ -331,9 +309,7 @@ inline void printElement(std::ostream &os, const std::filesystem::path &path) { os << "Path(" << path << ")"; } -inline void -printElement(std::ostream &os, - const std::filesystem::directory_iterator &dirIterator) { +inline void printElement(std::ostream &os, const std::filesystem::directory_iterator &dirIterator) { os << "Directory["; bool first = true; for (const auto &entry : dirIterator) { @@ -350,10 +326,9 @@ printElement(std::ostream &os, template void printElement(std::ostream &os, const UnsupportedArg &value) { const auto *typeName = typeid(UnsupportedArg).name(); - throw std::runtime_error( - "Type "s + std::string(typeName) + - "neither supports << operator for std::ostream nor is supported " - "out-of-the-box in logging functionality."s); + throw std::runtime_error("Type "s + std::string(typeName) + + "neither supports << operator for std::ostream nor is supported " + "out-of-the-box in logging functionality."s); } } // namespace low_level_log_implementation @@ -421,8 +396,7 @@ inline void handleIosLog(LOG_LEVEL logLevel, const char *buffer) { } #endif -inline std::string getBuffer(const std::string &logMessage, - std::size_t maxLogMessageSize) { +inline std::string getBuffer(const std::string &logMessage, std::size_t maxLogMessageSize) { if (logMessage.size() > maxLogMessageSize) { return logMessage.substr(0, maxLogMessageSize) + "..."; } @@ -466,8 +440,7 @@ template void log(LOG_LEVEL logLevel, Args &&...args) { auto oss = high_level_log_implementation::createConfiguredOutputStream(); auto space = [&oss](auto &&arg) { - low_level_log_implementation::printElement( - oss, std::forward(arg)); + low_level_log_implementation::printElement(oss, std::forward(arg)); oss << ' '; }; @@ -479,8 +452,7 @@ void log(LOG_LEVEL logLevel, Args &&...args) { output.pop_back(); } - const auto buffer = - high_level_log_implementation::getBuffer(output, MaxLogSize); + const auto buffer = high_level_log_implementation::getBuffer(output, MaxLogSize); const auto *cStyleBuffer = buffer.c_str(); #ifdef __ANDROID__ diff --git a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp index 53ee65a904..68be27335f 100644 --- a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp @@ -34,13 +34,12 @@ namespace rnexecutorch { // SSL intricacies manually, as it is done automagically in ObjC++/Kotlin. FetchUrlFunc_t fetchUrlFunc; -void RnExecutorchInstaller::injectJSIBindings( - jsi::Runtime *jsiRuntime, std::shared_ptr jsCallInvoker, - FetchUrlFunc_t fetchDataFromUrl, bool isEmulator) { +void RnExecutorchInstaller::injectJSIBindings(jsi::Runtime *jsiRuntime, + std::shared_ptr jsCallInvoker, + FetchUrlFunc_t fetchDataFromUrl, bool isEmulator) { fetchUrlFunc = fetchDataFromUrl; - jsiRuntime->global().setProperty(*jsiRuntime, "__rne_isEmulator", - jsi::Value(isEmulator)); + jsiRuntime->global().setProperty(*jsiRuntime, "__rne_isEmulator", jsi::Value(isEmulator)); jsiRuntime->global().setProperty( *jsiRuntime, "loadStyleTransfer", @@ -49,14 +48,12 @@ void RnExecutorchInstaller::injectJSIBindings( jsiRuntime->global().setProperty( *jsiRuntime, "loadSemanticSegmentation", - RnExecutorchInstaller::loadModel< - models::semantic_segmentation::BaseSemanticSegmentation>( + RnExecutorchInstaller::loadModel( jsiRuntime, jsCallInvoker, "loadSemanticSegmentation")); jsiRuntime->global().setProperty( *jsiRuntime, "loadInstanceSegmentation", - RnExecutorchInstaller::loadModel< - models::instance_segmentation::BaseInstanceSegmentation>( + RnExecutorchInstaller::loadModel( jsiRuntime, jsCallInvoker, "loadInstanceSegmentation")); jsiRuntime->global().setProperty( @@ -71,24 +68,21 @@ void RnExecutorchInstaller::injectJSIBindings( jsiRuntime->global().setProperty( *jsiRuntime, "loadObjectDetection", - RnExecutorchInstaller::loadModel< - models::object_detection::ObjectDetection>(jsiRuntime, jsCallInvoker, - "loadObjectDetection")); + RnExecutorchInstaller::loadModel( + jsiRuntime, jsCallInvoker, "loadObjectDetection")); jsiRuntime->global().setProperty( *jsiRuntime, "loadPoseEstimation", RnExecutorchInstaller::loadModel( jsiRuntime, jsCallInvoker, "loadPoseEstimation")); - jsiRuntime->global().setProperty( - *jsiRuntime, "loadExecutorchModule", - RnExecutorchInstaller::loadModel( - jsiRuntime, jsCallInvoker, "loadExecutorchModule")); + jsiRuntime->global().setProperty(*jsiRuntime, "loadExecutorchModule", + RnExecutorchInstaller::loadModel( + jsiRuntime, jsCallInvoker, "loadExecutorchModule")); - jsiRuntime->global().setProperty( - *jsiRuntime, "loadTokenizerModule", - RnExecutorchInstaller::loadModel( - jsiRuntime, jsCallInvoker, "loadTokenizerModule")); + jsiRuntime->global().setProperty(*jsiRuntime, "loadTokenizerModule", + RnExecutorchInstaller::loadModel( + jsiRuntime, jsCallInvoker, "loadTokenizerModule")); jsiRuntime->global().setProperty( *jsiRuntime, "loadImageEmbeddings", @@ -102,8 +96,7 @@ void RnExecutorchInstaller::injectJSIBindings( jsiRuntime->global().setProperty( *jsiRuntime, "loadLLM", - RnExecutorchInstaller::loadModel( - jsiRuntime, jsCallInvoker, "loadLLM")); + RnExecutorchInstaller::loadModel(jsiRuntime, jsCallInvoker, "loadLLM")); jsiRuntime->global().setProperty( *jsiRuntime, "loadPrivacyFilter", @@ -112,13 +105,11 @@ void RnExecutorchInstaller::injectJSIBindings( jsiRuntime->global().setProperty( *jsiRuntime, "loadOCR", - RnExecutorchInstaller::loadModel( - jsiRuntime, jsCallInvoker, "loadOCR")); + RnExecutorchInstaller::loadModel(jsiRuntime, jsCallInvoker, "loadOCR")); - jsiRuntime->global().setProperty( - *jsiRuntime, "loadVerticalOCR", - RnExecutorchInstaller::loadModel( - jsiRuntime, jsCallInvoker, "loadVerticalOCR")); + jsiRuntime->global().setProperty(*jsiRuntime, "loadVerticalOCR", + RnExecutorchInstaller::loadModel( + jsiRuntime, jsCallInvoker, "loadVerticalOCR")); jsiRuntime->global().setProperty( *jsiRuntime, "loadSpeechToText", @@ -132,8 +123,7 @@ void RnExecutorchInstaller::injectJSIBindings( jsiRuntime->global().setProperty( *jsiRuntime, "loadVAD", - RnExecutorchInstaller::loadModel< - models::voice_activity_detection::VoiceActivityDetection>( + RnExecutorchInstaller::loadModel( jsiRuntime, jsCallInvoker, "loadVAD")); threads::utils::unsafeSetupThreadPool(); diff --git a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h index 887a8ac0ea..f11d6b7d36 100644 --- a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h +++ b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h @@ -21,34 +21,29 @@ using namespace facebook; class RnExecutorchInstaller { public: - static void - injectJSIBindings(jsi::Runtime *jsiRuntime, - std::shared_ptr jsCallInvoker, - FetchUrlFunc_t fetchDataFromUrl, bool isEmulator); + static void injectJSIBindings(jsi::Runtime *jsiRuntime, + std::shared_ptr jsCallInvoker, + FetchUrlFunc_t fetchDataFromUrl, bool isEmulator); private: template - requires meta::ValidConstructorTraits && - meta::CallInvokerLastInConstructor && + requires meta::ValidConstructorTraits && meta::CallInvokerLastInConstructor && meta::ProvidesMemoryLowerBound - static jsi::Function - loadModel(jsi::Runtime *jsiRuntime, - std::shared_ptr jsCallInvoker, - const std::string &loadFunctionName) { + static jsi::Function loadModel(jsi::Runtime *jsiRuntime, + std::shared_ptr jsCallInvoker, + const std::string &loadFunctionName) { return jsi::Function::createFromHostFunction( - *jsiRuntime, jsi::PropNameID::forAscii(*jsiRuntime, loadFunctionName), - 0, - [jsCallInvoker](jsi::Runtime &runtime, const jsi::Value &thisValue, - const jsi::Value *args, size_t count) -> jsi::Value { - constexpr std::size_t expectedCount = std::tuple_size_v< - typename meta::ConstructorTraits::arg_types>; + *jsiRuntime, jsi::PropNameID::forAscii(*jsiRuntime, loadFunctionName), 0, + [jsCallInvoker](jsi::Runtime &runtime, const jsi::Value &thisValue, const jsi::Value *args, + size_t count) -> jsi::Value { + constexpr std::size_t expectedCount = + std::tuple_size_v::arg_types>; // count doesn't account for the JSCallInvoker if (count != expectedCount - 1) { char errorMessage[100]; - std::snprintf( - errorMessage, sizeof(errorMessage), - "Argument count mismatch, was expecting: %zu but got: %zu", - expectedCount, count); + std::snprintf(errorMessage, sizeof(errorMessage), + "Argument count mismatch, was expecting: %zu but got: %zu", expectedCount, + count); throw jsi::JSError(runtime, errorMessage); } @@ -56,59 +51,47 @@ class RnExecutorchInstaller { // access), then dispatch the heavy model construction to a background // thread and return a Promise. auto constructorArgs = - meta::createConstructorArgsWithCallInvoker(args, runtime, - jsCallInvoker); + meta::createConstructorArgsWithCallInvoker(args, runtime, jsCallInvoker); return Promise::createPromise( runtime, jsCallInvoker, - [jsCallInvoker, constructorArgs = std::move(constructorArgs)]( - std::shared_ptr promise) { - threads::GlobalThreadPool::detach([jsCallInvoker, promise, - constructorArgs = std::move( - constructorArgs)]() { - try { - auto modelImplementationPtr = std::apply( - [](auto &&...unpackedArgs) { - return std::make_shared( - std::forward( - unpackedArgs)...); - }, - std::move(constructorArgs)); + [jsCallInvoker, + constructorArgs = std::move(constructorArgs)](std::shared_ptr promise) { + threads::GlobalThreadPool::detach( + [jsCallInvoker, promise, constructorArgs = std::move(constructorArgs)]() { + try { + auto modelImplementationPtr = std::apply( + [](auto &&...unpackedArgs) { + return std::make_shared( + std::forward(unpackedArgs)...); + }, + std::move(constructorArgs)); - auto modelHostObject = - std::make_shared>( + auto modelHostObject = std::make_shared>( modelImplementationPtr, jsCallInvoker); - auto memoryLowerBound = - modelImplementationPtr->getMemoryLowerBound(); + auto memoryLowerBound = modelImplementationPtr->getMemoryLowerBound(); - jsCallInvoker->invokeAsync([promise, modelHostObject, - memoryLowerBound]( - jsi::Runtime &rt) { - auto jsiObject = jsi::Object::createFromHostObject( - rt, modelHostObject); - jsiObject.setExternalMemoryPressure(rt, memoryLowerBound); - promise->resolve(std::move(jsiObject)); - }); - } catch (const rnexecutorch::RnExecutorchError &e) { - auto code = e.getNumericCode(); - auto msg = std::string(e.what()); - jsCallInvoker->invokeAsync( - [promise, code, msg](jsi::Runtime &rt) { - promise->reject( - makeRnExecutorchErrorValue(rt, code, msg)); + jsCallInvoker->invokeAsync([promise, modelHostObject, + memoryLowerBound](jsi::Runtime &rt) { + auto jsiObject = jsi::Object::createFromHostObject(rt, modelHostObject); + jsiObject.setExternalMemoryPressure(rt, memoryLowerBound); + promise->resolve(std::move(jsiObject)); }); - } catch (const std::exception &e) { - jsCallInvoker->invokeAsync( - [promise, msg = std::string(e.what())]() { - promise->reject(msg); + } catch (const rnexecutorch::RnExecutorchError &e) { + auto code = e.getNumericCode(); + auto msg = std::string(e.what()); + jsCallInvoker->invokeAsync([promise, code, msg](jsi::Runtime &rt) { + promise->reject(makeRnExecutorchErrorValue(rt, code, msg)); }); - } catch (...) { - jsCallInvoker->invokeAsync([promise]() { - promise->reject(std::string("Unknown error")); + } catch (const std::exception &e) { + jsCallInvoker->invokeAsync( + [promise, msg = std::string(e.what())]() { promise->reject(msg); }); + } catch (...) { + jsCallInvoker->invokeAsync( + [promise]() { promise->reject(std::string("Unknown error")); }); + } }); - } - }); }); }); } diff --git a/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.cpp b/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.cpp index 76e0fb90c7..2cf9dd57f4 100644 --- a/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.cpp @@ -11,16 +11,15 @@ namespace rnexecutorch { using namespace facebook; using namespace executorch::extension::llm; -TokenizerModule::TokenizerModule( - std::string source, std::shared_ptr callInvoker) +TokenizerModule::TokenizerModule(std::string source, + std::shared_ptr callInvoker) : tokenizer(std::make_unique()) { auto status = tokenizer->load(source); if (status != tokenizers::Error::Ok) { - throw RnExecutorchError( - RnExecutorchErrorCode::TokenizerError, - "Unexpected issue occurred while loading tokenizer"); + throw RnExecutorchError(RnExecutorchErrorCode::TokenizerError, + "Unexpected issue occurred while loading tokenizer"); }; std::filesystem::path modelPath{source}; memorySizeLowerBound = std::filesystem::file_size(modelPath); @@ -35,29 +34,25 @@ std::vector TokenizerModule::encode(std::string s) const { // setting any of bos or eos arguments to value other than provided constant // ( which is 0) will result in running the post_processor with // 'add_special_token' flag - auto encodeResult = - tokenizer->encode(s, numOfAddedBoSTokens, numOfAddedEoSTokens); + auto encodeResult = tokenizer->encode(s, numOfAddedBoSTokens, numOfAddedEoSTokens); if (!encodeResult.ok()) { - throw RnExecutorchError( - RnExecutorchErrorCode::TokenizerError, - "Unexpected issue occurred while encoding: " + - std::to_string(static_cast(encodeResult.error()))); + throw RnExecutorchError(RnExecutorchErrorCode::TokenizerError, + "Unexpected issue occurred while encoding: " + + std::to_string(static_cast(encodeResult.error()))); } return encodeResult.get(); } -std::string TokenizerModule::decode(std::vector vec, - bool skipSpecialTokens) const { +std::string TokenizerModule::decode(std::vector vec, bool skipSpecialTokens) const { if (!tokenizer) { THROW_NOT_LOADED_ERROR(); } auto decodeResult = tokenizer->decode(vec, skipSpecialTokens); if (!decodeResult.ok()) { - throw RnExecutorchError( - RnExecutorchErrorCode::TokenizerError, - "Unexpected issue occurred while decoding: " + - std::to_string(static_cast(decodeResult.error()))); + throw RnExecutorchError(RnExecutorchErrorCode::TokenizerError, + "Unexpected issue occurred while decoding: " + + std::to_string(static_cast(decodeResult.error()))); } return decodeResult.get(); @@ -76,10 +71,9 @@ std::string TokenizerModule::idToToken(uint64_t tokenId) const { } auto result = tokenizer->id_to_piece(tokenId); if (!result.ok()) { - throw RnExecutorchError( - RnExecutorchErrorCode::TokenizerError, - "Unexpected issue occurred while converting id to token: " + - std::to_string(static_cast(result.error()))); + throw RnExecutorchError(RnExecutorchErrorCode::TokenizerError, + "Unexpected issue occurred while converting id to token: " + + std::to_string(static_cast(result.error()))); } return result.get(); } @@ -91,16 +85,13 @@ uint64_t TokenizerModule::tokenToId(std::string token) const { auto result = tokenizer->piece_to_id(token); if (!result.ok()) { - throw RnExecutorchError( - RnExecutorchErrorCode::TokenizerError, - "Unexpected issue occurred while converting token to id: " + - std::to_string(static_cast(result.error()))); + throw RnExecutorchError(RnExecutorchErrorCode::TokenizerError, + "Unexpected issue occurred while converting token to id: " + + std::to_string(static_cast(result.error()))); } return result.get(); } -std::size_t TokenizerModule::getMemoryLowerBound() const noexcept { - return memorySizeLowerBound; -} +std::size_t TokenizerModule::getMemoryLowerBound() const noexcept { return memorySizeLowerBound; } } // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.h b/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.h index 3c90b25557..52d2ecaf97 100644 --- a/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.h +++ b/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.h @@ -9,18 +9,13 @@ using namespace facebook; class TokenizerModule { public: - explicit TokenizerModule(std::string source, - std::shared_ptr callInvoker); - [[nodiscard("Registered non-void function")]] std::vector - encode(std::string s) const; - [[nodiscard("Registered non-void function")]] std::string - decode(std::vector vec, bool skipSpecialTokens) const; - [[nodiscard("Registered non-void function")]] std::string - idToToken(uint64_t tokenId) const; - [[nodiscard("Registered non-void function")]] uint64_t - tokenToId(std::string token) const; - [[nodiscard("Registered non-void function")]] std::size_t - getVocabSize() const; + explicit TokenizerModule(std::string source, std::shared_ptr callInvoker); + [[nodiscard("Registered non-void function")]] std::vector encode(std::string s) const; + [[nodiscard("Registered non-void function")]] std::string decode(std::vector vec, + bool skipSpecialTokens) const; + [[nodiscard("Registered non-void function")]] std::string idToToken(uint64_t tokenId) const; + [[nodiscard("Registered non-void function")]] uint64_t tokenToId(std::string token) const; + [[nodiscard("Registered non-void function")]] std::size_t getVocabSize() const; std::size_t getMemoryLowerBound() const noexcept; private: @@ -28,6 +23,5 @@ class TokenizerModule { std::size_t memorySizeLowerBound{0}; }; -REGISTER_CONSTRUCTOR(TokenizerModule, std::string, - std::shared_ptr); +REGISTER_CONSTRUCTOR(TokenizerModule, std::string, std::shared_ptr); } // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/FFT.cpp b/packages/react-native-executorch/common/rnexecutorch/data_processing/FFT.cpp index 4884938294..fdb8190e03 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/FFT.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/FFT.cpp @@ -13,8 +13,7 @@ FFT::~FFT() { } void FFT::doFFT(float *in, std::vector> &out) { - pffft_transform_ordered(pffftSetup_, in, - reinterpret_cast(out.data()), work_, + pffft_transform_ordered(pffftSetup_, in, reinterpret_cast(out.data()), work_, PFFFT_FORWARD); } diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp b/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp index 3e73a3d8a4..25e0c4decf 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp @@ -16,12 +16,10 @@ namespace rnexecutorch { extern FetchUrlFunc_t fetchUrlFunc; namespace image_processing { std::vector colorMatToVector(const cv::Mat &mat) { - return colorMatToVector(mat, cv::Scalar(0.0, 0.0, 0.0), - cv::Scalar(1.0, 1.0, 1.0)); + return colorMatToVector(mat, cv::Scalar(0.0, 0.0, 0.0), cv::Scalar(1.0, 1.0, 1.0)); } -std::vector colorMatToVector(const cv::Mat &mat, cv::Scalar mean, - cv::Scalar variance) { +std::vector colorMatToVector(const cv::Mat &mat, cv::Scalar mean, cv::Scalar variance) { int pixelCount = mat.cols * mat.rows; std::vector v(pixelCount * 3); @@ -29,19 +27,15 @@ std::vector colorMatToVector(const cv::Mat &mat, cv::Scalar mean, int row = i / mat.cols; int col = i % mat.cols; cv::Vec3b pixel = mat.at(row, col); - v[0 * pixelCount + i] = - (pixel[0] - mean[0] * 255.0) / (variance[0] * 255.0); - v[1 * pixelCount + i] = - (pixel[1] - mean[1] * 255.0) / (variance[1] * 255.0); - v[2 * pixelCount + i] = - (pixel[2] - mean[2] * 255.0) / (variance[2] * 255.0); + v[0 * pixelCount + i] = (pixel[0] - mean[0] * 255.0) / (variance[0] * 255.0); + v[1 * pixelCount + i] = (pixel[1] - mean[1] * 255.0) / (variance[1] * 255.0); + v[2 * pixelCount + i] = (pixel[2] - mean[2] * 255.0) / (variance[2] * 255.0); } return v; } -cv::Mat bufferToColorMat(const std::span &buffer, - cv::Size matSize) { +cv::Mat bufferToColorMat(const std::span &buffer, cv::Size matSize) { cv::Mat mat(matSize, CV_8UC3); int pixelCount = matSize.width * matSize.height; @@ -100,9 +94,8 @@ cv::Mat readImage(const std::string &imageURI) { } else if (imageURI.starts_with("http")) { // remote file std::vector imageData = fetchUrlFunc(imageURI); - image = cv::imdecode( - cv::Mat(1, imageData.size(), CV_8UC1, (void *)imageData.data()), - cv::IMREAD_COLOR); + image = cv::imdecode(cv::Mat(1, imageData.size(), CV_8UC1, (void *)imageData.data()), + cv::IMREAD_COLOR); } else { // fallback to raw base64 content auto data = base64_decode(imageURI); @@ -118,23 +111,18 @@ cv::Mat readImage(const std::string &imageURI) { return image; } -TensorPtr getTensorFromMatrix(const std::vector &tensorDims, - const cv::Mat &matrix) { - return executorch::extension::make_tensor_ptr(tensorDims, - colorMatToVector(matrix)); +TensorPtr getTensorFromMatrix(const std::vector &tensorDims, const cv::Mat &matrix) { + return executorch::extension::make_tensor_ptr(tensorDims, colorMatToVector(matrix)); } -TensorPtr getTensorFromMatrix(const std::vector &tensorDims, - const cv::Mat &matrix, cv::Scalar mean, - cv::Scalar variance) { - return executorch::extension::make_tensor_ptr( - tensorDims, colorMatToVector(matrix, mean, variance)); +TensorPtr getTensorFromMatrix(const std::vector &tensorDims, const cv::Mat &matrix, + cv::Scalar mean, cv::Scalar variance) { + return executorch::extension::make_tensor_ptr(tensorDims, + colorMatToVector(matrix, mean, variance)); } -TensorPtr getTensorFromMatrixGray(const std::vector &tensorDims, - const cv::Mat &matrix) { - return executorch::extension::make_tensor_ptr(tensorDims, - grayMatToVector(matrix)); +TensorPtr getTensorFromMatrixGray(const std::vector &tensorDims, const cv::Mat &matrix) { + return executorch::extension::make_tensor_ptr(tensorDims, grayMatToVector(matrix)); } std::vector grayMatToVector(const cv::Mat &mat) { @@ -153,37 +141,29 @@ std::vector grayMatToVector(const cv::Mat &mat) { cv::Mat getMatrixFromTensor(cv::Size size, const Tensor &tensor) { auto resultData = static_cast(tensor.const_data_ptr()); - return bufferToColorMat(std::span(resultData, tensor.numel()), - size); + return bufferToColorMat(std::span(resultData, tensor.numel()), size); } cv::Mat resizePadded(const cv::Mat inputImage, cv::Size targetSize) { cv::Size inputSize = inputImage.size(); - const float heightRatio = - static_cast(targetSize.height) / inputSize.height; - const float widthRatio = - static_cast(targetSize.width) / inputSize.width; + const float heightRatio = static_cast(targetSize.height) / inputSize.height; + const float widthRatio = static_cast(targetSize.width) / inputSize.width; const float resizeRatio = std::min(heightRatio, widthRatio); const int newWidth = inputSize.width * resizeRatio; const int newHeight = inputSize.height * resizeRatio; cv::Mat resizedImg; - cv::resize(inputImage, resizedImg, cv::Size(newWidth, newHeight), 0, 0, - cv::INTER_AREA); + cv::resize(inputImage, resizedImg, cv::Size(newWidth, newHeight), 0, 0, cv::INTER_AREA); constexpr int minCornerPatchSize = 1; constexpr int cornerPatchFractionSize = 30; - int cornerPatchSize = - std::min(inputSize.height, inputSize.width) / cornerPatchFractionSize; + int cornerPatchSize = std::min(inputSize.height, inputSize.width) / cornerPatchFractionSize; cornerPatchSize = std::max(minCornerPatchSize, cornerPatchSize); const std::array corners = { inputImage(cv::Rect(0, 0, cornerPatchSize, cornerPatchSize)), - inputImage(cv::Rect(inputSize.width - cornerPatchSize, 0, cornerPatchSize, - cornerPatchSize)), - inputImage(cv::Rect(0, inputSize.height - cornerPatchSize, - cornerPatchSize, cornerPatchSize)), - inputImage(cv::Rect(inputSize.width - cornerPatchSize, - inputSize.height - cornerPatchSize, cornerPatchSize, - cornerPatchSize))}; + inputImage(cv::Rect(inputSize.width - cornerPatchSize, 0, cornerPatchSize, cornerPatchSize)), + inputImage(cv::Rect(0, inputSize.height - cornerPatchSize, cornerPatchSize, cornerPatchSize)), + inputImage(cv::Rect(inputSize.width - cornerPatchSize, inputSize.height - cornerPatchSize, + cornerPatchSize, cornerPatchSize))}; // We choose the color of the padding based on a mean of colors in the corners // of an image. @@ -208,17 +188,17 @@ cv::Mat resizePadded(const cv::Mat inputImage, cv::Size targetSize) { const int right = deltaW - left; cv::Mat centeredImg; - cv::copyMakeBorder(resizedImg, centeredImg, top, bottom, left, right, - cv::BORDER_CONSTANT, backgroundScalar); + cv::copyMakeBorder(resizedImg, centeredImg, top, bottom, left, right, cv::BORDER_CONSTANT, + backgroundScalar); return centeredImg; } -std::pair -readImageToTensor(const std::string &path, - const std::vector &tensorDims, - bool maintainAspectRatio, std::optional normMean, - std::optional normStd) { +std::pair readImageToTensor(const std::string &path, + const std::vector &tensorDims, + bool maintainAspectRatio, + std::optional normMean, + std::optional normStd) { cv::Mat input = image_processing::readImage(path); cv::Size imageSize = input.size(); @@ -228,11 +208,10 @@ readImageToTensor(const std::string &path, "Unexpected tensor size, expected at least 2 dimensions " "but got: %zu.", tensorDims.size()); - throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, - errorMessage); + throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, errorMessage); } - cv::Size tensorSize = cv::Size(tensorDims[tensorDims.size() - 1], - tensorDims[tensorDims.size() - 2]); + cv::Size tensorSize = + cv::Size(tensorDims[tensorDims.size() - 1], tensorDims[tensorDims.size() - 2]); if (maintainAspectRatio) { input = resizePadded(input, tensorSize); @@ -243,9 +222,9 @@ readImageToTensor(const std::string &path, cv::cvtColor(input, input, cv::COLOR_BGR2RGB); if (normMean.has_value() && normStd.has_value()) { - return {image_processing::getTensorFromMatrix( - tensorDims, input, normMean.value(), normStd.value()), - imageSize}; + return { + image_processing::getTensorFromMatrix(tensorDims, input, normMean.value(), normStd.value()), + imageSize}; } return {image_processing::getTensorFromMatrix(tensorDims, input), imageSize}; } diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.h b/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.h index 8b371f87e3..d00c14050e 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.h +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.h @@ -13,25 +13,20 @@ using executorch::aten::Tensor; using executorch::extension::TensorPtr; /// @brief Convert a OpenCV matrix to channel-first vector representation -std::vector colorMatToVector(const cv::Mat &mat, cv::Scalar mean, - cv::Scalar variance); +std::vector colorMatToVector(const cv::Mat &mat, cv::Scalar mean, cv::Scalar variance); /// @brief Convert a OpenCV matrix to channel-first vector representation std::vector colorMatToVector(const cv::Mat &mat); /// @brief Convert a channel-first representation of an RGB image to OpenCV /// matrix -cv::Mat bufferToColorMat(const std::span &buffer, - cv::Size matSize); +cv::Mat bufferToColorMat(const std::span &buffer, cv::Size matSize); std::string saveToTempFile(const cv::Mat &image); /// @brief Read image in a BGR format to a cv::Mat cv::Mat readImage(const std::string &imageURI); -TensorPtr getTensorFromMatrix(const std::vector &tensorDims, - const cv::Mat &mat); -TensorPtr getTensorFromMatrix(const std::vector &tensorDims, - const cv::Mat &matrix, cv::Scalar mean, - cv::Scalar variance); +TensorPtr getTensorFromMatrix(const std::vector &tensorDims, const cv::Mat &mat); +TensorPtr getTensorFromMatrix(const std::vector &tensorDims, const cv::Mat &matrix, + cv::Scalar mean, cv::Scalar variance); cv::Mat getMatrixFromTensor(cv::Size size, const Tensor &tensor); -TensorPtr getTensorFromMatrixGray(const std::vector &tensorDims, - const cv::Mat &matrix); +TensorPtr getTensorFromMatrixGray(const std::vector &tensorDims, const cv::Mat &matrix); std::vector grayMatToVector(const cv::Mat &mat); /** * @brief Resizes an image to fit within target dimensions while preserving @@ -48,12 +43,11 @@ cv::Mat resizePadded(const cv::Mat inputImage, cv::Size targetSize); /// maintain the original aspect ratio. The rest of the tensor will be filled /// padding. /// @return Returns a tensor pointer and the original size of the image. -std::pair -readImageToTensor(const std::string &path, - const std::vector &tensorDims, - bool maintainAspectRatio = false, - std::optional normMean = std::nullopt, - std::optional normStd = std::nullopt); +std::pair readImageToTensor(const std::string &path, + const std::vector &tensorDims, + bool maintainAspectRatio = false, + std::optional normMean = std::nullopt, + std::optional normStd = std::nullopt); /** * @brief Applies sigmoid activation to logits and converts to uint8 binary mask * @param logits Input matrix containing raw logits (pre-sigmoid) diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.cpp b/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.cpp index 4b19cc99ba..a1c6a1fea7 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.cpp @@ -37,9 +37,8 @@ void softmaxWithTemperature(std::span input, float temperature) { } if (temperature <= 0.0F) { - throw RnExecutorchError( - RnExecutorchErrorCode::InvalidConfig, - "Temperature must be greater than 0 for softmax with temperature!"); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidConfig, + "Temperature must be greater than 0 for softmax with temperature!"); } const auto maxElement = *std::ranges::max_element(input); @@ -57,8 +56,7 @@ void softmaxWithTemperature(std::span input, float temperature) { } void normalize(std::span input) { - const auto sumOfSquares = - std::inner_product(input.begin(), input.end(), input.begin(), 0.0F); + const auto sumOfSquares = std::inner_product(input.begin(), input.end(), input.begin(), 0.0F); constexpr auto kEpsilon = 1.0e-15F; @@ -76,8 +74,7 @@ std::vector meanPooling(std::span modelOutput, ss << "Invalid dimensions for mean pooling, expected model output size to " "be divisible " << "by the size of attention mask but got size: " << modelOutput.size() - << " for model output and size: " << attnMask.size() - << " for attention mask"; + << " for model output and size: " << attnMask.size() << " for attention mask"; throw RnExecutorchError(RnExecutorchErrorCode::InvalidConfig, ss.str()); } @@ -105,9 +102,7 @@ std::vector meanPooling(std::span modelOutput, return result; } -template bool isClose(T a, T b, T atol) { - return std::abs(a - b) <= atol; -} +template bool isClose(T a, T b, T atol) { return std::abs(a - b) <= atol; } template bool isClose(float, float, float); template bool isClose(double, double, double); diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.h b/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.h index 3a292cfbbf..588070617d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.h +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.h @@ -68,8 +68,7 @@ std::vector meanPooling(std::span modelOutput, * @brief Checks if two floating-point numbers are considered equal. */ template -bool isClose(T a, T b, - T atol = std::numeric_limits::epsilon() * static_cast(10)); +bool isClose(T a, T b, T atol = std::numeric_limits::epsilon() * static_cast(10)); extern template bool isClose(float, float, float); extern template bool isClose(double, double, double); diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/Sequential.h b/packages/react-native-executorch/common/rnexecutorch/data_processing/Sequential.h index a872352f68..1a32b51ad1 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/Sequential.h +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/Sequential.h @@ -28,14 +28,12 @@ namespace rnexecutorch::sequential { * @return A std::vector containing the repeated elements in order. */ template -std::vector repeatInterleave(std::span data, - std::span repetitions) { +std::vector repeatInterleave(std::span data, std::span repetitions) { if (data.size() != repetitions.size()) { throw std::invalid_argument( "repeatInterleave(): repetitions vector must be the same size as data," " expected " + - std::to_string(data.size()) + " but got " + - std::to_string(repetitions.size())); + std::to_string(data.size()) + " but got " + std::to_string(repetitions.size())); } IType totalReps = std::reduce(repetitions.begin(), repetitions.end()); @@ -50,4 +48,4 @@ std::vector repeatInterleave(std::span data, return result; } -} // namespace rnexecutorch::sequential \ No newline at end of file +} // namespace rnexecutorch::sequential diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/base64.cpp b/packages/react-native-executorch/common/rnexecutorch/data_processing/base64.cpp index a79bee9b8f..4e33605360 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/base64.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/base64.cpp @@ -5,9 +5,7 @@ static const std::string base64_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" "abcdefghijklmnopqrstuvwxyz" "0123456789+/"; -static inline bool is_base64(BYTE c) { - return (isalnum(c) || (c == '+') || (c == '/')); -} +static inline bool is_base64(BYTE c) { return (isalnum(c) || (c == '+') || (c == '/')); } std::string base64_encode(BYTE const *buf, unsigned int bufLen) { std::string ret; @@ -20,10 +18,8 @@ std::string base64_encode(BYTE const *buf, unsigned int bufLen) { char_array_3[i++] = *(buf++); if (i == 3) { char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; - char_array_4[1] = - ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); - char_array_4[2] = - ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); + char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); + char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); char_array_4[3] = char_array_3[2] & 0x3f; for (i = 0; (i < 4); i++) { @@ -39,10 +35,8 @@ std::string base64_encode(BYTE const *buf, unsigned int bufLen) { } char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; - char_array_4[1] = - ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); - char_array_4[2] = - ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); + char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); + char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); char_array_4[3] = char_array_3[2] & 0x3f; for (j = 0; (j < i + 1); j++) { @@ -65,8 +59,7 @@ std::vector base64_decode(std::string const &encoded_string) { BYTE char_array_4[4], char_array_3[3]; std::vector ret; - while (in_len-- && (encoded_string[in_] != '=') && - is_base64(encoded_string[in_])) { + while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { char_array_4[i++] = encoded_string[in_]; in_++; if (i == 4) { @@ -74,10 +67,8 @@ std::vector base64_decode(std::string const &encoded_string) { char_array_4[i] = base64_chars.find(char_array_4[i]); } - char_array_3[0] = - (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = - ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; for (i = 0; (i < 3); i++) { @@ -97,8 +88,7 @@ std::vector base64_decode(std::string const &encoded_string) { } char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = - ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; for (j = 0; (j < i - 1); j++) { @@ -107,4 +97,4 @@ std::vector base64_decode(std::string const &encoded_string) { } return ret; -} \ No newline at end of file +} diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/base64.h b/packages/react-native-executorch/common/rnexecutorch/data_processing/base64.h index 5ebaaa84f9..16e2ee0ab2 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/base64.h +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/base64.h @@ -43,4 +43,4 @@ typedef unsigned char BYTE; std::string base64_encode(BYTE const *buf, unsigned int bufLen); std::vector base64_decode(std::string const &); -#endif \ No newline at end of file +#endif diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.cpp b/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.cpp index 38282883f3..3e09dfd46c 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.cpp @@ -10,13 +10,15 @@ namespace rnexecutorch::dsp { using std::numbers::pi_v; -//https://www.mathworks.com/help/signal/ref/hann.html +// https://www.mathworks.com/help/signal/ref/hann.html std::vector hannWindow(size_t size) { - std::vector window(size); - for (size_t i = 0; i < size; i++) { - window[i] = 0.5f * (1.0f - std::cosf(2.0f * pi_v * static_cast(i) / static_cast(size))); - } - return window; -} + std::vector window(size); + for (size_t i = 0; i < size; i++) { + window[i] = + 0.5f * + (1.0f - std::cosf(2.0f * pi_v * static_cast(i) / static_cast(size))); + } + return window; +} } // namespace rnexecutorch::dsp diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.h b/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.h index 7eaa26d831..4cef5a7e0e 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.h +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/dsp.h @@ -6,7 +6,7 @@ namespace rnexecutorch::dsp { std::vector hannWindow(size_t size); -std::vector stftFromWaveform(std::span waveform, - size_t fftWindowSize, size_t hopSize); +std::vector stftFromWaveform(std::span waveform, size_t fftWindowSize, + size_t hopSize); } // namespace rnexecutorch::dsp diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/gzip.cpp b/packages/react-native-executorch/common/rnexecutorch/data_processing/gzip.cpp index aeda796d0b..01b168e3ab 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/gzip.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/gzip.cpp @@ -15,17 +15,14 @@ constexpr size_t kChunkSize = 16 * 1024; // 16 KiB stream buffer size_t deflateSize(const std::string &input) { z_stream strm{}; - if (::deflateInit2(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED, - MAX_WBITS + kGzipWrapper, kMemLevel, + if (::deflateInit2(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED, MAX_WBITS + kGzipWrapper, kMemLevel, Z_DEFAULT_STRATEGY) != Z_OK) { - throw RnExecutorchError(RnExecutorchErrorCode::UnknownError, - "deflateInit2 failed"); + throw RnExecutorchError(RnExecutorchErrorCode::UnknownError, "deflateInit2 failed"); } size_t outSize = 0; - strm.next_in = reinterpret_cast( - const_cast(input.data())); + strm.next_in = reinterpret_cast(const_cast(input.data())); strm.avail_in = static_cast(input.size()); std::vector buf(kChunkSize); @@ -37,8 +34,7 @@ size_t deflateSize(const std::string &input) { ret = ::deflate(&strm, strm.avail_in ? Z_NO_FLUSH : Z_FINISH); if (ret == Z_STREAM_ERROR) { ::deflateEnd(&strm); - throw RnExecutorchError(RnExecutorchErrorCode::UnknownError, - "deflate stream error"); + throw RnExecutorchError(RnExecutorchErrorCode::UnknownError, "deflate stream error"); } outSize += buf.size() - strm.avail_out; diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewOut.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewOut.h index 5d2f885021..86c2ce6b2e 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewOut.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewOut.h @@ -16,7 +16,6 @@ struct JSTensorViewOut { JSTensorViewOut(std::vector sizes, ScalarType scalarType, std::shared_ptr dataPtr) - : dataPtr(std::move(dataPtr)), sizes(std::move(sizes)), - scalarType(scalarType) {} + : dataPtr(std::move(dataPtr)), sizes(std::move(sizes)), scalarType(scalarType) {} }; } // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index e4209b2f79..b974488129 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -47,20 +47,16 @@ inline T getValue(const jsi::Value &val, jsi::Runtime &runtime) { return static_cast(val.asNumber()); } -template <> -inline bool getValue(const jsi::Value &val, jsi::Runtime &runtime) { +template <> inline bool getValue(const jsi::Value &val, jsi::Runtime &runtime) { return val.asBool(); } -template <> -inline std::string getValue(const jsi::Value &val, - jsi::Runtime &runtime) { +template <> inline std::string getValue(const jsi::Value &val, jsi::Runtime &runtime) { return val.getString(runtime).utf8(runtime); } template <> -inline std::u32string getValue(const jsi::Value &val, - jsi::Runtime &runtime) { +inline std::u32string getValue(const jsi::Value &val, jsi::Runtime &runtime) { std::string utf8 = getValue(val, runtime); std::wstring_convert, char32_t> conv; @@ -69,15 +65,12 @@ inline std::u32string getValue(const jsi::Value &val, template <> inline std::shared_ptr -getValue>(const jsi::Value &val, - jsi::Runtime &runtime) { - return std::make_shared( - val.asObject(runtime).asFunction(runtime)); +getValue>(const jsi::Value &val, jsi::Runtime &runtime) { + return std::make_shared(val.asObject(runtime).asFunction(runtime)); } template <> -inline JSTensorViewIn getValue(const jsi::Value &val, - jsi::Runtime &runtime) { +inline JSTensorViewIn getValue(const jsi::Value &val, jsi::Runtime &runtime) { jsi::Object obj = val.asObject(runtime); JSTensorViewIn tensorView; @@ -90,8 +83,7 @@ inline JSTensorViewIn getValue(const jsi::Value &val, tensorView.sizes.reserve(numShapeDims); for (size_t i = 0; i < numShapeDims; ++i) { - int32_t dim = - getValue(shapeArray.getValueAtIndex(runtime, i), runtime); + int32_t dim = getValue(shapeArray.getValueAtIndex(runtime, i), runtime); tensorView.sizes.push_back(dim); } @@ -106,27 +98,21 @@ inline JSTensorViewIn getValue(const jsi::Value &val, tensorView.dataPtr = arrayBuffer.data(runtime); } else { // Handle typed arrays (Float32Array, Int32Array, etc.) - const bool isValidTypedArray = dataObj.hasProperty(runtime, "buffer") && - dataObj.hasProperty(runtime, "byteOffset") && - dataObj.hasProperty(runtime, "byteLength") && - dataObj.hasProperty(runtime, "length"); + const bool isValidTypedArray = + dataObj.hasProperty(runtime, "buffer") && dataObj.hasProperty(runtime, "byteOffset") && + dataObj.hasProperty(runtime, "byteLength") && dataObj.hasProperty(runtime, "length"); if (!isValidTypedArray) { throw jsi::JSError(runtime, "Data must be an ArrayBuffer or TypedArray"); } jsi::Value bufferValue = dataObj.getProperty(runtime, "buffer"); - if (!bufferValue.isObject() || - !bufferValue.asObject(runtime).isArrayBuffer(runtime)) { - throw jsi::JSError(runtime, - "TypedArray buffer property must be an ArrayBuffer"); + if (!bufferValue.isObject() || !bufferValue.asObject(runtime).isArrayBuffer(runtime)) { + throw jsi::JSError(runtime, "TypedArray buffer property must be an ArrayBuffer"); } - jsi::ArrayBuffer arrayBuffer = - bufferValue.asObject(runtime).getArrayBuffer(runtime); - size_t byteOffset = - getValue(dataObj.getProperty(runtime, "byteOffset"), runtime); + jsi::ArrayBuffer arrayBuffer = bufferValue.asObject(runtime).getArrayBuffer(runtime); + size_t byteOffset = getValue(dataObj.getProperty(runtime, "byteOffset"), runtime); - tensorView.dataPtr = - static_cast(arrayBuffer.data(runtime)) + byteOffset; + tensorView.dataPtr = static_cast(arrayBuffer.data(runtime)) + byteOffset; } return tensorView; } @@ -135,8 +121,7 @@ inline JSTensorViewIn getValue(const jsi::Value &val, // enables querying with std::string_view). template <> inline std::set> -getValue>>(const jsi::Value &val, - jsi::Runtime &runtime) { +getValue>>(const jsi::Value &val, jsi::Runtime &runtime) { jsi::Array array = val.asObject(runtime).asArray(runtime); size_t length = array.size(runtime); @@ -151,41 +136,34 @@ getValue>>(const jsi::Value &val, // Helper function to convert typed arrays to std::span template -inline std::span getTypedArrayAsSpan(const jsi::Value &val, - jsi::Runtime &runtime) { +inline std::span getTypedArrayAsSpan(const jsi::Value &val, jsi::Runtime &runtime) { jsi::Object obj = val.asObject(runtime); - const bool isValidTypedArray = obj.hasProperty(runtime, "buffer") && - obj.hasProperty(runtime, "byteOffset") && - obj.hasProperty(runtime, "byteLength") && - obj.hasProperty(runtime, "length"); + const bool isValidTypedArray = + obj.hasProperty(runtime, "buffer") && obj.hasProperty(runtime, "byteOffset") && + obj.hasProperty(runtime, "byteLength") && obj.hasProperty(runtime, "length"); if (!isValidTypedArray) { throw jsi::JSError(runtime, "Value must be a TypedArray"); } // Get the underlying ArrayBuffer jsi::Value bufferValue = obj.getProperty(runtime, "buffer"); - if (!bufferValue.isObject() || - !bufferValue.asObject(runtime).isArrayBuffer(runtime)) { - throw jsi::JSError(runtime, - "TypedArray buffer property must be an ArrayBuffer"); + if (!bufferValue.isObject() || !bufferValue.asObject(runtime).isArrayBuffer(runtime)) { + throw jsi::JSError(runtime, "TypedArray buffer property must be an ArrayBuffer"); } - jsi::ArrayBuffer arrayBuffer = - bufferValue.asObject(runtime).getArrayBuffer(runtime); - size_t byteOffset = - getValue(obj.getProperty(runtime, "byteOffset"), runtime); + jsi::ArrayBuffer arrayBuffer = bufferValue.asObject(runtime).getArrayBuffer(runtime); + size_t byteOffset = getValue(obj.getProperty(runtime, "byteOffset"), runtime); size_t length = getValue(obj.getProperty(runtime, "length"), runtime); - T *dataPtr = reinterpret_cast( - static_cast(arrayBuffer.data(runtime)) + byteOffset); + T *dataPtr = + reinterpret_cast(static_cast(arrayBuffer.data(runtime)) + byteOffset); return {dataPtr, length}; } template -inline std::vector getArrayAsVector(const jsi::Value &val, - jsi::Runtime &runtime) { +inline std::vector getArrayAsVector(const jsi::Value &val, jsi::Runtime &runtime) { jsi::Array array = val.asObject(runtime).asArray(runtime); const size_t length = array.size(runtime); std::vector result; @@ -200,22 +178,20 @@ inline std::vector getArrayAsVector(const jsi::Value &val, // Template specializations for std::vector types template <> -inline std::vector -getValue>(const jsi::Value &val, - jsi::Runtime &runtime) { +inline std::vector getValue>(const jsi::Value &val, + jsi::Runtime &runtime) { return getArrayAsVector(val, runtime); } template <> -inline std::vector -getValue>(const jsi::Value &val, - jsi::Runtime &runtime) { +inline std::vector getValue>(const jsi::Value &val, + jsi::Runtime &runtime) { return getArrayAsVector(val, runtime); } template <> -inline std::vector -getValue>(const jsi::Value &val, jsi::Runtime &runtime) { +inline std::vector getValue>(const jsi::Value &val, + jsi::Runtime &runtime) { return getArrayAsVector(val, runtime); } @@ -227,8 +203,7 @@ inline std::vector getValue>(const jsi::Value &val, template <> inline std::vector> -getValue>>(const jsi::Value &val, - jsi::Runtime &runtime) { +getValue>>(const jsi::Value &val, jsi::Runtime &runtime) { jsi::Array array = val.asObject(runtime).asArray(runtime); const size_t length = array.size(runtime); std::vector> result; @@ -242,33 +217,31 @@ getValue>>(const jsi::Value &val, } template <> -inline std::vector -getValue>(const jsi::Value &val, jsi::Runtime &runtime) { +inline std::vector getValue>(const jsi::Value &val, + jsi::Runtime &runtime) { return getArrayAsVector(val, runtime); } template <> -inline std::vector -getValue>(const jsi::Value &val, jsi::Runtime &runtime) { +inline std::vector getValue>(const jsi::Value &val, + jsi::Runtime &runtime) { return getArrayAsVector(val, runtime); } // Template specializations for std::span types template <> -inline std::span getValue>(const jsi::Value &val, - jsi::Runtime &runtime) { +inline std::span getValue>(const jsi::Value &val, jsi::Runtime &runtime) { return getTypedArrayAsSpan(val, runtime); } template <> -inline std::span -getValue>(const jsi::Value &val, jsi::Runtime &runtime) { +inline std::span getValue>(const jsi::Value &val, + jsi::Runtime &runtime) { return getTypedArrayAsSpan(val, runtime); } template <> -inline std::span getValue>(const jsi::Value &val, - jsi::Runtime &runtime) { +inline std::span getValue>(const jsi::Value &val, jsi::Runtime &runtime) { return getTypedArrayAsSpan(val, runtime); } @@ -279,8 +252,8 @@ inline std::span getValue>(const jsi::Value &val, } template <> -inline std::span -getValue>(const jsi::Value &val, jsi::Runtime &runtime) { +inline std::span getValue>(const jsi::Value &val, + jsi::Runtime &runtime) { return getTypedArrayAsSpan(val, runtime); } @@ -291,14 +264,13 @@ inline std::span getValue>(const jsi::Value &val, } template <> -inline std::span -getValue>(const jsi::Value &val, jsi::Runtime &runtime) { +inline std::span getValue>(const jsi::Value &val, + jsi::Runtime &runtime) { return getTypedArrayAsSpan(val, runtime); } template <> -inline std::span getValue>(const jsi::Value &val, - jsi::Runtime &runtime) { +inline std::span getValue>(const jsi::Value &val, jsi::Runtime &runtime) { return getTypedArrayAsSpan(val, runtime); } @@ -315,15 +287,14 @@ inline std::span getValue>(const jsi::Value &val, } template <> -inline std::span -getValue>(const jsi::Value &val, jsi::Runtime &runtime) { +inline std::span getValue>(const jsi::Value &val, + jsi::Runtime &runtime) { return getTypedArrayAsSpan(val, runtime); } template <> inline models::llm::MultimodalInputs -getValue(const jsi::Value &val, - jsi::Runtime &runtime) { +getValue(const jsi::Value &val, jsi::Runtime &runtime) { models::llm::MultimodalInputs multimodalInputs; jsi::Object obj = val.asObject(runtime); @@ -351,13 +322,11 @@ getValue(const jsi::Value &val, // return jsi::Value or jsi::Object. For each type being returned // we add a function here. -inline jsi::Value getJsiValue(std::shared_ptr valuePtr, - jsi::Runtime &runtime) { +inline jsi::Value getJsiValue(std::shared_ptr valuePtr, jsi::Runtime &runtime) { return std::move(*valuePtr); } -inline jsi::Value getJsiValue(const std::vector &vec, - jsi::Runtime &runtime) { +inline jsi::Value getJsiValue(const std::vector &vec, jsi::Runtime &runtime) { jsi::Array array(runtime, vec.size()); for (size_t i = 0; i < vec.size(); i++) { array.setValueAtIndex(runtime, i, jsi::Value(static_cast(vec[i]))); @@ -365,8 +334,7 @@ inline jsi::Value getJsiValue(const std::vector &vec, return {runtime, array}; } -inline jsi::Value getJsiValue(const std::vector &vec, - jsi::Runtime &runtime) { +inline jsi::Value getJsiValue(const std::vector &vec, jsi::Runtime &runtime) { jsi::Array array(runtime, vec.size()); for (size_t i = 0; i < vec.size(); i++) { array.setValueAtIndex(runtime, i, jsi::Value(static_cast(vec[i]))); @@ -374,8 +342,7 @@ inline jsi::Value getJsiValue(const std::vector &vec, return {runtime, array}; } -inline jsi::Value getJsiValue(const std::vector &vec, - jsi::Runtime &runtime) { +inline jsi::Value getJsiValue(const std::vector &vec, jsi::Runtime &runtime) { jsi::Array array(runtime, vec.size()); for (size_t i = 0; i < vec.size(); i++) { array.setValueAtIndex(runtime, i, jsi::Value(vec[i])); @@ -383,18 +350,15 @@ inline jsi::Value getJsiValue(const std::vector &vec, return {runtime, array}; } -inline jsi::Value getJsiValue(const std::vector &vec, - jsi::Runtime &runtime) { +inline jsi::Value getJsiValue(const std::vector &vec, jsi::Runtime &runtime) { jsi::Array array(runtime, vec.size()); for (size_t i = 0; i < vec.size(); i++) { - array.setValueAtIndex(runtime, i, - jsi::String::createFromUtf8(runtime, vec[i])); + array.setValueAtIndex(runtime, i, jsi::String::createFromUtf8(runtime, vec[i])); } return {runtime, array}; } -inline jsi::Value getJsiValue(const std::vector &vec, - jsi::Runtime &runtime) { +inline jsi::Value getJsiValue(const std::vector &vec, jsi::Runtime &runtime) { jsi::Array array(runtime, vec.size()); for (size_t i = 0; i < vec.size(); i++) { array.setValueAtIndex(runtime, i, jsi::Value(vec[i])); @@ -402,8 +366,7 @@ inline jsi::Value getJsiValue(const std::vector &vec, return {runtime, array}; } -inline jsi::Value getJsiValue(const std::vector &vec, - jsi::Runtime &runtime) { +inline jsi::Value getJsiValue(const std::vector &vec, jsi::Runtime &runtime) { jsi::Array array(runtime, vec.size()); for (size_t i = 0; i < vec.size(); i++) { array.setValueAtIndex(runtime, i, jsi::Value(static_cast(vec[i]))); @@ -411,9 +374,9 @@ inline jsi::Value getJsiValue(const std::vector &vec, return {runtime, array}; } -inline jsi::Value getJsiValue( - const rnexecutorch::models::pose_estimation::PersonKeypoints &keypoints, - jsi::Runtime &runtime) { +inline jsi::Value +getJsiValue(const rnexecutorch::models::pose_estimation::PersonKeypoints &keypoints, + jsi::Runtime &runtime) { jsi::Array array(runtime, keypoints.size()); for (size_t i = 0; i < keypoints.size(); ++i) { jsi::Object point(runtime); @@ -425,9 +388,9 @@ inline jsi::Value getJsiValue( } // Pose estimation: all detected people (vector of person keypoints) -inline jsi::Value getJsiValue( - const rnexecutorch::models::pose_estimation::PoseDetections &detections, - jsi::Runtime &runtime) { +inline jsi::Value +getJsiValue(const rnexecutorch::models::pose_estimation::PoseDetections &detections, + jsi::Runtime &runtime) { jsi::Array array(runtime, detections.size()); for (size_t i = 0; i < detections.size(); ++i) { array.setValueAtIndex(runtime, i, getJsiValue(detections[i], runtime)); @@ -437,9 +400,8 @@ inline jsi::Value getJsiValue( // Conditional as on android, size_t and uint64_t reduce to the same type, // introducing ambiguity -template && - !std::is_same_v>> +template && + !std::is_same_v>> inline jsi::Value getJsiValue(T val, jsi::Runtime &runtime) { return jsi::Value(static_cast(val)); } @@ -449,13 +411,9 @@ inline jsi::Value getJsiValue(uint64_t val, jsi::Runtime &runtime) { return {runtime, bigInt}; } -inline jsi::Value getJsiValue(int val, jsi::Runtime &runtime) { - return {runtime, val}; -} +inline jsi::Value getJsiValue(int val, jsi::Runtime &runtime) { return {runtime, val}; } -inline jsi::Value getJsiValue(bool val, jsi::Runtime &runtime) { - return jsi::Value(val); -} +inline jsi::Value getJsiValue(bool val, jsi::Runtime &runtime) { return jsi::Value(val); } inline jsi::Value getJsiValue(const std::shared_ptr &buf, jsi::Runtime &runtime) { @@ -463,9 +421,8 @@ inline jsi::Value getJsiValue(const std::shared_ptr &buf, return {runtime, arrayBuffer}; } -inline jsi::Value -getJsiValue(const std::vector> &vec, - jsi::Runtime &runtime) { +inline jsi::Value getJsiValue(const std::vector> &vec, + jsi::Runtime &runtime) { jsi::Array array(runtime, vec.size()); for (size_t i = 0; i < vec.size(); i++) { jsi::ArrayBuffer arrayBuffer(runtime, vec[i]); @@ -474,16 +431,14 @@ getJsiValue(const std::vector> &vec, return {runtime, array}; } -inline jsi::Value getJsiValue(const std::vector &vec, - jsi::Runtime &runtime) { +inline jsi::Value getJsiValue(const std::vector &vec, jsi::Runtime &runtime) { jsi::Array array(runtime, vec.size()); for (size_t i = 0; i < vec.size(); i++) { jsi::Object tensorObj(runtime); tensorObj.setProperty(runtime, "sizes", getJsiValue(vec[i].sizes, runtime)); - tensorObj.setProperty(runtime, "scalarType", - jsi::Value(static_cast(vec[i].scalarType))); + tensorObj.setProperty(runtime, "scalarType", jsi::Value(static_cast(vec[i].scalarType))); jsi::ArrayBuffer arrayBuffer(runtime, vec[i].dataPtr); tensorObj.setProperty(runtime, "dataPtr", arrayBuffer); @@ -497,9 +452,8 @@ inline jsi::Value getJsiValue(const std::string &str, jsi::Runtime &runtime) { return jsi::String::createFromUtf8(runtime, str); } -inline jsi::Value -getJsiValue(const std::unordered_map &map, - jsi::Runtime &runtime) { +inline jsi::Value getJsiValue(const std::unordered_map &map, + jsi::Runtime &runtime) { jsi::Object mapObj{runtime}; for (auto &[k, v] : map) { // The string_view keys must be null-terminated! @@ -508,8 +462,7 @@ getJsiValue(const std::unordered_map &map, return mapObj; } -inline jsi::Value getJsiValue(const utils::computer_vision::BBox &bbox, - jsi::Runtime &runtime) { +inline jsi::Value getJsiValue(const utils::computer_vision::BBox &bbox, jsi::Runtime &runtime) { jsi::Object obj(runtime); obj.setProperty(runtime, "x1", bbox.p1.x); obj.setProperty(runtime, "y1", bbox.p1.y); @@ -518,17 +471,15 @@ inline jsi::Value getJsiValue(const utils::computer_vision::BBox &bbox, return obj; } -inline jsi::Value getJsiValue( - const std::vector &detections, - jsi::Runtime &runtime) { +inline jsi::Value +getJsiValue(const std::vector &detections, + jsi::Runtime &runtime) { jsi::Array array(runtime, detections.size()); for (std::size_t i = 0; i < detections.size(); ++i) { jsi::Object detection(runtime); - detection.setProperty(runtime, "bbox", - getJsiValue(detections[i].bbox, runtime)); - detection.setProperty( - runtime, "label", - jsi::String::createFromUtf8(runtime, detections[i].label)); + detection.setProperty(runtime, "bbox", getJsiValue(detections[i].bbox, runtime)); + detection.setProperty(runtime, "label", + jsi::String::createFromUtf8(runtime, detections[i].label)); detection.setProperty(runtime, "score", detections[i].score); array.setValueAtIndex(runtime, i, detection); } @@ -536,22 +487,18 @@ inline jsi::Value getJsiValue( } inline jsi::Value -getJsiValue(const std::vector - &instances, +getJsiValue(const std::vector &instances, jsi::Runtime &runtime) { jsi::Array array(runtime, instances.size()); for (std::size_t i = 0; i < instances.size(); ++i) { jsi::Object instance(runtime); - instance.setProperty(runtime, "bbox", - getJsiValue(instances[i].bbox, runtime)); + instance.setProperty(runtime, "bbox", getJsiValue(instances[i].bbox, runtime)); // Mask as Uint8Array - reuse existing OwningArrayBuffer jsi::ArrayBuffer arrayBuffer(runtime, instances[i].mask); - auto uint8ArrayCtor = - runtime.global().getPropertyAsFunction(runtime, "Uint8Array"); - auto uint8Array = uint8ArrayCtor.callAsConstructor(runtime, arrayBuffer) - .getObject(runtime); + auto uint8ArrayCtor = runtime.global().getPropertyAsFunction(runtime, "Uint8Array"); + auto uint8Array = uint8ArrayCtor.callAsConstructor(runtime, arrayBuffer).getObject(runtime); instance.setProperty(runtime, "mask", uint8Array); instance.setProperty(runtime, "maskWidth", instances[i].maskWidth); @@ -566,19 +513,17 @@ getJsiValue(const std::vector return array; } -inline jsi::Value -getJsiValue(const std::vector &detections, - jsi::Runtime &runtime) { +inline jsi::Value getJsiValue(const std::vector &detections, + jsi::Runtime &runtime) { auto jsiDetections = jsi::Array(runtime, detections.size()); for (size_t i = 0; i < detections.size(); ++i) { const auto &detection = detections[i]; auto jsiDetectionObject = jsi::Object(runtime); - jsiDetectionObject.setProperty(runtime, "bbox", - getJsiValue(detection.bbox, runtime)); - jsiDetectionObject.setProperty( - runtime, "text", jsi::String::createFromUtf8(runtime, detection.text)); + jsiDetectionObject.setProperty(runtime, "bbox", getJsiValue(detection.bbox, runtime)); + jsiDetectionObject.setProperty(runtime, "text", + jsi::String::createFromUtf8(runtime, detection.text)); jsiDetectionObject.setProperty(runtime, "score", detection.score); jsiDetections.setValueAtIndex(runtime, i, jsiDetectionObject); @@ -588,8 +533,7 @@ getJsiValue(const std::vector &detections, } inline jsi::Value -getJsiValue(const std::vector - &speechSegments, +getJsiValue(const std::vector &speechSegments, jsi::Runtime &runtime) { auto jsiSegments = jsi::Array(runtime, speechSegments.size()); for (size_t i = 0; i < speechSegments.size(); i++) { @@ -602,17 +546,14 @@ getJsiValue(const std::vector return jsiSegments; } -inline jsi::Value getJsiValue( - const std::vector &entities, - jsi::Runtime &runtime) { +inline jsi::Value getJsiValue(const std::vector &entities, + jsi::Runtime &runtime) { auto jsiEntities = jsi::Array(runtime, entities.size()); for (size_t i = 0; i < entities.size(); i++) { const auto &e = entities[i]; auto obj = jsi::Object(runtime); - obj.setProperty(runtime, "label", - jsi::String::createFromUtf8(runtime, e.label)); - obj.setProperty(runtime, "text", - jsi::String::createFromUtf8(runtime, e.text)); + obj.setProperty(runtime, "label", jsi::String::createFromUtf8(runtime, e.label)); + obj.setProperty(runtime, "text", jsi::String::createFromUtf8(runtime, e.text)); obj.setProperty(runtime, "startToken", e.startToken); obj.setProperty(runtime, "endToken", e.endToken); jsiEntities.setValueAtIndex(runtime, i, obj); @@ -628,8 +569,7 @@ inline jsi::Value getJsiValue(const Segment &seg, jsi::Runtime &runtime) { std::string segText; for (const auto &w : seg.words) segText += w.content; - obj.setProperty(runtime, "text", - jsi::String::createFromUtf8(runtime, segText)); + obj.setProperty(runtime, "text", jsi::String::createFromUtf8(runtime, segText)); obj.setProperty(runtime, "avgLogprob", seg.avgLogprob); obj.setProperty(runtime, "compressionRatio", seg.compressionRatio); @@ -638,11 +578,9 @@ inline jsi::Value getJsiValue(const Segment &seg, jsi::Runtime &runtime) { jsi::Array wordsArray(runtime, seg.words.size()); for (size_t i = 0; i < seg.words.size(); ++i) { jsi::Object wordObj(runtime); - wordObj.setProperty( - runtime, "word", - jsi::String::createFromUtf8(runtime, seg.words[i].content)); - wordObj.setProperty(runtime, "start", - static_cast(seg.words[i].start)); + wordObj.setProperty(runtime, "word", + jsi::String::createFromUtf8(runtime, seg.words[i].content)); + wordObj.setProperty(runtime, "start", static_cast(seg.words[i].start)); wordObj.setProperty(runtime, "end", static_cast(seg.words[i].end)); wordsArray.setValueAtIndex(runtime, i, wordObj); @@ -658,41 +596,33 @@ inline jsi::Value getJsiValue(const Segment &seg, jsi::Runtime &runtime) { return obj; } -inline jsi::Value getJsiValue(const TranscriptionResult &result, - jsi::Runtime &runtime) { +inline jsi::Value getJsiValue(const TranscriptionResult &result, jsi::Runtime &runtime) { jsi::Object obj(runtime); - obj.setProperty(runtime, "text", - jsi::String::createFromUtf8(runtime, result.text)); + obj.setProperty(runtime, "text", jsi::String::createFromUtf8(runtime, result.text)); if (!result.segments.empty() || !result.language.empty()) { - obj.setProperty(runtime, "task", - jsi::String::createFromUtf8(runtime, result.task)); + obj.setProperty(runtime, "task", jsi::String::createFromUtf8(runtime, result.task)); if (!result.language.empty()) { - obj.setProperty(runtime, "language", - jsi::String::createFromUtf8(runtime, result.language)); + obj.setProperty(runtime, "language", jsi::String::createFromUtf8(runtime, result.language)); } obj.setProperty(runtime, "duration", result.duration); jsi::Array segmentsArray(runtime, result.segments.size()); for (size_t i = 0; i < result.segments.size(); ++i) { - segmentsArray.setValueAtIndex(runtime, i, - getJsiValue(result.segments[i], runtime)); + segmentsArray.setValueAtIndex(runtime, i, getJsiValue(result.segments[i], runtime)); } obj.setProperty(runtime, "segments", segmentsArray); } return obj; } -inline jsi::Value -getJsiValue(const models::style_transfer::PixelDataResult &result, - jsi::Runtime &runtime) { +inline jsi::Value getJsiValue(const models::style_transfer::PixelDataResult &result, + jsi::Runtime &runtime) { jsi::Object obj(runtime); auto arrayBuffer = jsi::ArrayBuffer(runtime, result.dataPtr); - auto uint8ArrayCtor = - runtime.global().getPropertyAsFunction(runtime, "Uint8Array"); - auto uint8Array = - uint8ArrayCtor.callAsConstructor(runtime, arrayBuffer).getObject(runtime); + auto uint8ArrayCtor = runtime.global().getPropertyAsFunction(runtime, "Uint8Array"); + auto uint8Array = uint8ArrayCtor.callAsConstructor(runtime, arrayBuffer).getObject(runtime); obj.setProperty(runtime, "dataPtr", uint8Array); auto sizesArray = jsi::Array(runtime, 3); @@ -701,45 +631,35 @@ getJsiValue(const models::style_transfer::PixelDataResult &result, sizesArray.setValueAtIndex(runtime, 2, jsi::Value(result.channels)); obj.setProperty(runtime, "sizes", sizesArray); - obj.setProperty(runtime, "scalarType", - jsi::Value(static_cast(ScalarType::Byte))); + obj.setProperty(runtime, "scalarType", jsi::Value(static_cast(ScalarType::Byte))); return obj; } -inline jsi::Value getJsiValue( - const rnexecutorch::models::semantic_segmentation::SegmentationResult - &result, - jsi::Runtime &runtime) { +inline jsi::Value +getJsiValue(const rnexecutorch::models::semantic_segmentation::SegmentationResult &result, + jsi::Runtime &runtime) { jsi::Object dict(runtime); auto argmaxArrayBuffer = jsi::ArrayBuffer(runtime, result.argmax); - auto int32ArrayCtor = - runtime.global().getPropertyAsFunction(runtime, "Int32Array"); - auto int32Array = int32ArrayCtor.callAsConstructor(runtime, argmaxArrayBuffer) - .getObject(runtime); + auto int32ArrayCtor = runtime.global().getPropertyAsFunction(runtime, "Int32Array"); + auto int32Array = int32ArrayCtor.callAsConstructor(runtime, argmaxArrayBuffer).getObject(runtime); dict.setProperty(runtime, "ARGMAX", int32Array); for (auto &[classLabel, owningBuffer] : *result.classBuffers) { auto classArrayBuffer = jsi::ArrayBuffer(runtime, owningBuffer); - auto float32ArrayCtor = - runtime.global().getPropertyAsFunction(runtime, "Float32Array"); + auto float32ArrayCtor = runtime.global().getPropertyAsFunction(runtime, "Float32Array"); auto float32Array = - float32ArrayCtor.callAsConstructor(runtime, classArrayBuffer) - .getObject(runtime); - dict.setProperty(runtime, jsi::String::createFromAscii(runtime, classLabel), - float32Array); + float32ArrayCtor.callAsConstructor(runtime, classArrayBuffer).getObject(runtime); + dict.setProperty(runtime, jsi::String::createFromAscii(runtime, classLabel), float32Array); } return dict; } -inline jsi::Value -getJsiValue(const models::style_transfer::StyleTransferResult &result, - jsi::Runtime &runtime) { - return std::visit( - [&runtime](const auto &value) { return getJsiValue(value, runtime); }, - result); +inline jsi::Value getJsiValue(const models::style_transfer::StyleTransferResult &result, + jsi::Runtime &runtime) { + return std::visit([&runtime](const auto &value) { return getJsiValue(value, runtime); }, result); } } // namespace rnexecutorch::jsi_conversion diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h index cb631e5fba..137128fc00 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h @@ -29,15 +29,12 @@ namespace rnexecutorch { -inline jsi::Value makeRnExecutorchErrorValue(jsi::Runtime &runtime, - int32_t code, +inline jsi::Value makeRnExecutorchErrorValue(jsi::Runtime &runtime, int32_t code, const std::string &message) { - auto errorObj = - runtime.global() - .getPropertyAsFunction(runtime, "Error") - .callAsConstructor(runtime, - jsi::String::createFromUtf8(runtime, message)) - .asObject(runtime); + auto errorObj = runtime.global() + .getPropertyAsFunction(runtime, "Error") + .callAsConstructor(runtime, jsi::String::createFromUtf8(runtime, message)) + .asObject(runtime); errorObj.setProperty(runtime, "code", code); return jsi::Value(std::move(errorObj)); } @@ -48,212 +45,169 @@ template class ModelHostObject : public JsiHostObject { std::shared_ptr callInvoker) : model(model), callInvoker(callInvoker) { if constexpr (meta::DerivedFromOrSameAs) { - addFunctions( - JSI_EXPORT_FUNCTION(ModelHostObject, unload, "unload")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, unload, "unload")); addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, - promiseHostFunction<&Model::forwardJS>, - "forward")); + promiseHostFunction<&Model::forwardJS>, "forward")); addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, promiseHostFunction<&Model::getInputShape>, - "getInputShape")); + ModelHostObject, promiseHostFunction<&Model::getInputShape>, "getInputShape")); } // LLM::generate and LLM::generateMultimodal registered explicitly below - if constexpr (meta::HasGenerate && - !meta::SameAs) { + if constexpr (meta::HasGenerate && !meta::SameAs) { addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, - promiseHostFunction<&Model::generate>, - "generate")); + promiseHostFunction<&Model::generate>, "generate")); } if constexpr (meta::HasEncode) { - addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, - promiseHostFunction<&Model::encode>, + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, promiseHostFunction<&Model::encode>, "encode")); } if constexpr (meta::HasDecode) { - addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, - promiseHostFunction<&Model::decode>, + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, promiseHostFunction<&Model::decode>, "decode")); } if constexpr (meta::SameAs) { addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, - synchronousHostFunction<&Model::unload>, - "unload")); + synchronousHostFunction<&Model::unload>, "unload")); addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, - promiseHostFunction<&Model::transcribe>, - "transcribe")); + promiseHostFunction<&Model::transcribe>, "transcribe")); - addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, - promiseHostFunction<&Model::stream>, + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, promiseHostFunction<&Model::stream>, "stream")); addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, synchronousHostFunction<&Model::streamInsert>, - "streamInsert")); + ModelHostObject, synchronousHostFunction<&Model::streamInsert>, "streamInsert")); - addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, synchronousHostFunction<&Model::streamStop>, - "streamStop")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + synchronousHostFunction<&Model::streamStop>, "streamStop")); } - if constexpr (meta::SameAs) { - addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, - promiseHostFunction<&Model::stream>, + if constexpr (meta::SameAs) { + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, promiseHostFunction<&Model::stream>, "stream")); addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, synchronousHostFunction<&Model::streamInsert>, - "streamInsert")); + ModelHostObject, synchronousHostFunction<&Model::streamInsert>, "streamInsert")); - addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, synchronousHostFunction<&Model::streamStop>, - "streamStop")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + synchronousHostFunction<&Model::streamStop>, "streamStop")); } if constexpr (meta::SameAs) { - addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, promiseHostFunction<&Model::getVocabSize>, - "getVocabSize")); addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, - promiseHostFunction<&Model::idToToken>, - "idToToken")); + promiseHostFunction<&Model::getVocabSize>, "getVocabSize")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + promiseHostFunction<&Model::idToToken>, "idToToken")); addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, - promiseHostFunction<&Model::tokenToId>, - "tokenToId")); + promiseHostFunction<&Model::tokenToId>, "tokenToId")); } if constexpr (meta::SameAs) { addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, - promiseHostFunction<&Model::generate>, - "generate")); + promiseHostFunction<&Model::generate>, "generate")); - addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, synchronousHostFunction<&Model::interrupt>, - "interrupt")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + synchronousHostFunction<&Model::interrupt>, "interrupt")); - addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, - synchronousHostFunction<&Model::getGeneratedTokenCount>, - "getGeneratedTokenCount")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + synchronousHostFunction<&Model::getGeneratedTokenCount>, + "getGeneratedTokenCount")); - addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, - synchronousHostFunction<&Model::getPromptTokenCount>, - "getPromptTokenCount")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + synchronousHostFunction<&Model::getPromptTokenCount>, + "getPromptTokenCount")); - addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, - synchronousHostFunction<&Model::countTextTokens>, "countTextTokens")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + synchronousHostFunction<&Model::countTextTokens>, + "countTextTokens")); - addFunctions( - JSI_EXPORT_FUNCTION(ModelHostObject, - synchronousHostFunction<&Model::setCountInterval>, - "setCountInterval")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + synchronousHostFunction<&Model::setCountInterval>, + "setCountInterval")); - addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, - synchronousHostFunction<&Model::setTimeInterval>, "setTimeInterval")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + synchronousHostFunction<&Model::setTimeInterval>, + "setTimeInterval")); - addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, - synchronousHostFunction<&Model::setTemperature>, "setTemperature")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + synchronousHostFunction<&Model::setTemperature>, + "setTemperature")); addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, - synchronousHostFunction<&Model::setTopp>, - "setTopp")); + synchronousHostFunction<&Model::setTopp>, "setTopp")); addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, - synchronousHostFunction<&Model::setMinP>, - "setMinP")); + synchronousHostFunction<&Model::setMinP>, "setMinP")); - addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, - synchronousHostFunction<&Model::setRepetitionPenalty>, - "setRepetitionPenalty")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + synchronousHostFunction<&Model::setRepetitionPenalty>, + "setRepetitionPenalty")); - addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, - synchronousHostFunction<&Model::getMaxContextLength>, - "getMaxContextLength")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + synchronousHostFunction<&Model::getMaxContextLength>, + "getMaxContextLength")); - addFunctions( - JSI_EXPORT_FUNCTION(ModelHostObject, unload, "unload")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, unload, "unload")); addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, - synchronousHostFunction<&Model::reset>, - "reset")); + synchronousHostFunction<&Model::reset>, "reset")); - addFunctions( - JSI_EXPORT_FUNCTION(ModelHostObject, - promiseHostFunction<&Model::generateMultimodal>, - "generateMultimodal")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + promiseHostFunction<&Model::generateMultimodal>, + "generateMultimodal")); - addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, - synchronousHostFunction<&Model::getVisualTokenCount>, - "getVisualTokenCount")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + synchronousHostFunction<&Model::getVisualTokenCount>, + "getVisualTokenCount")); } if constexpr (meta::SameAs) { - addFunctions( - JSI_EXPORT_FUNCTION(ModelHostObject, unload, "unload")); - addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, synchronousHostFunction<&Model::interrupt>, - "interrupt")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, unload, "unload")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + synchronousHostFunction<&Model::interrupt>, "interrupt")); } if constexpr (meta::SameAs) { - addFunctions( - JSI_EXPORT_FUNCTION(ModelHostObject, unload, "unload")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, unload, "unload")); } if constexpr (meta::SameAs) { - addFunctions( - JSI_EXPORT_FUNCTION(ModelHostObject, unload, "unload")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, unload, "unload")); } if constexpr (meta::SameAs) { - addFunctions( - JSI_EXPORT_FUNCTION(ModelHostObject, unload, "unload")); - addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, - promiseHostFunction<&Model::stream>, + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, unload, "unload")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, promiseHostFunction<&Model::stream>, "stream")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + synchronousHostFunction<&Model::streamStop>, "streamStop")); addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, synchronousHostFunction<&Model::streamStop>, - "streamStop")); - addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, synchronousHostFunction<&Model::streamInsert>, - "streamInsert")); + ModelHostObject, synchronousHostFunction<&Model::streamInsert>, "streamInsert")); addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, synchronousHostFunction<&Model::streamFlush>, - "streamFlush")); + ModelHostObject, synchronousHostFunction<&Model::streamFlush>, "streamFlush")); } if constexpr (meta::HasGenerateFromString) { - addFunctions( - JSI_EXPORT_FUNCTION(ModelHostObject, - promiseHostFunction<&Model::generateFromString>, - "generateFromString")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + promiseHostFunction<&Model::generateFromString>, + "generateFromString")); } if constexpr (meta::HasGenerateFromFrame) { - addFunctions(JSI_EXPORT_FUNCTION( - ModelHostObject, visionHostFunction<&Model::generateFromFrame>, - "generateFromFrame")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + visionHostFunction<&Model::generateFromFrame>, + "generateFromFrame")); } if constexpr (meta::HasGenerateFromPixels) { - addFunctions( - JSI_EXPORT_FUNCTION(ModelHostObject, - promiseHostFunction<&Model::generateFromPixels>, - "generateFromPixels")); + addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject, + promiseHostFunction<&Model::generateFromPixels>, + "generateFromPixels")); } } @@ -263,8 +217,7 @@ template class ModelHostObject : public JsiHostObject { constexpr std::size_t functionArgCount = meta::getArgumentCount(FnPtr); if (functionArgCount != count) { std::stringstream ss; - ss << "Argument count mismatch, was expecting: " << functionArgCount - << " but got: " << count; + ss << "Argument count mismatch, was expecting: " << functionArgCount << " but got: " << count; const auto errorMessage = ss.str(); throw jsi::JSError(runtime, errorMessage); } @@ -272,20 +225,19 @@ template class ModelHostObject : public JsiHostObject { try { auto argsConverted = meta::createArgsTupleFromJsi(FnPtr, args, runtime); - if constexpr (std::is_void_v) { + if constexpr (std::is_void_v) { // For void functions, just call the function and return undefined std::apply(std::bind_front(FnPtr, model), std::move(argsConverted)); return jsi::Value::undefined(); } else { // For non-void functions, capture the result and convert it - auto result = - std::apply(std::bind_front(FnPtr, model), std::move(argsConverted)); + auto result = std::apply(std::bind_front(FnPtr, model), std::move(argsConverted)); return jsi_conversion::getJsiValue(std::move(result), runtime); } } catch (const RnExecutorchError &e) { - throw jsi::JSError(runtime, makeRnExecutorchErrorValue( - runtime, e.getNumericCode(), e.what())); + throw jsi::JSError(runtime, + makeRnExecutorchErrorValue(runtime, e.getNumericCode(), e.what())); } catch (const std::exception &e) { throw jsi::JSError(runtime, e.what()); } catch (...) { @@ -314,8 +266,7 @@ template class ModelHostObject : public JsiHostObject { * */ template JSI_HOST_FUNCTION(visionHostFunction) { - constexpr std::size_t cppArgCount = - meta::FunctionTraits::arity; + constexpr std::size_t cppArgCount = meta::FunctionTraits::arity; constexpr std::size_t expectedJsArgs = cppArgCount - 1; if (count != expectedJsArgs) { @@ -324,35 +275,31 @@ template class ModelHostObject : public JsiHostObject { try { auto dummyFuncPtr = &meta::TailSignature::dummy; - auto tailArgsTuple = - meta::createArgsTupleFromJsi(dummyFuncPtr, args + 1, runtime); + auto tailArgsTuple = meta::createArgsTupleFromJsi(dummyFuncPtr, args + 1, runtime); - using ReturnType = - typename meta::FunctionTraits::return_type; + using ReturnType = typename meta::FunctionTraits::return_type; if constexpr (std::is_void_v) { std::apply( [&](auto &&...tailArgs) { - (model.get()->*FnPtr)( - runtime, args[0], - std::forward(tailArgs)...); + (model.get()->*FnPtr)(runtime, args[0], + std::forward(tailArgs)...); }, std::move(tailArgsTuple)); return jsi::Value::undefined(); } else { auto result = std::apply( [&](auto &&...tailArgs) { - return (model.get()->*FnPtr)( - runtime, args[0], - std::forward(tailArgs)...); + return (model.get()->*FnPtr)(runtime, args[0], + std::forward(tailArgs)...); }, std::move(tailArgsTuple)); return jsi_conversion::getJsiValue(std::move(result), runtime); } } catch (const RnExecutorchError &e) { - throw jsi::JSError(runtime, makeRnExecutorchErrorValue( - runtime, e.getNumericCode(), e.what())); + throw jsi::JSError(runtime, + makeRnExecutorchErrorValue(runtime, e.getNumericCode(), e.what())); } catch (const std::exception &e) { throw jsi::JSError(runtime, e.what()); } catch (...) { @@ -365,10 +312,8 @@ template class ModelHostObject : public JsiHostObject { // signature, and the return value is converted back to JSI before resolving. template JSI_HOST_FUNCTION(promiseHostFunction) { auto promise = Promise::createPromise( - runtime, callInvoker, - [this, count, args, &runtime](std::shared_ptr promise) { - constexpr std::size_t functionArgCount = - meta::getArgumentCount(FnPtr); + runtime, callInvoker, [this, count, args, &runtime](std::shared_ptr promise) { + constexpr std::size_t functionArgCount = meta::getArgumentCount(FnPtr); if (functionArgCount != count) { std::stringstream ss; ss << "Argument count mismatch, was expecting: " << functionArgCount @@ -379,64 +324,52 @@ template class ModelHostObject : public JsiHostObject { } try { - auto argsConverted = - meta::createArgsTupleFromJsi(FnPtr, args, runtime); + auto argsConverted = meta::createArgsTupleFromJsi(FnPtr, args, runtime); // We need to dispatch a thread if we want the function to be // asynchronous. In this thread all accesses to jsi::Runtime need to // be done via the callInvoker. - threads::GlobalThreadPool::detach( - [model = this->model, callInvoker = this->callInvoker, promise, - argsConverted = std::move(argsConverted)]() { - try { - if constexpr (std::is_void_v) { - // For void functions, just call the function and resolve - // with undefined - std::apply(std::bind_front(FnPtr, model), - std::move(argsConverted)); - callInvoker->invokeAsync( - [promise](jsi::Runtime &runtime) { - promise->resolve(jsi::Value::undefined()); - }); - } else { - // For non-void functions, capture the result and convert - // it - auto result = std::apply(std::bind_front(FnPtr, model), - std::move(argsConverted)); - // The result is copied. It should either be quickly - // copiable, or passed with a shared_ptr. - callInvoker->invokeAsync( - [promise, result](jsi::Runtime &runtime) { - promise->resolve(jsi_conversion::getJsiValue( - std::move(result), runtime)); - }); - } - } catch (const RnExecutorchError &e) { - auto code = e.getNumericCode(); - auto msg = std::string(e.what()); - callInvoker->invokeAsync( - [code, msg, promise](jsi::Runtime &runtime) { - promise->reject( - makeRnExecutorchErrorValue(runtime, code, msg)); - }); - return; - } catch (const std::exception &e) { - callInvoker->invokeAsync([e = std::move(e), promise]() { - promise->reject(std::string(e.what())); - }); - return; - } catch (...) { - callInvoker->invokeAsync([promise]() { - promise->reject(std::string("Unknown error")); - }); - return; - } + threads::GlobalThreadPool::detach([model = this->model, callInvoker = this->callInvoker, + promise, + argsConverted = std::move(argsConverted)]() { + try { + if constexpr (std::is_void_v) { + // For void functions, just call the function and resolve + // with undefined + std::apply(std::bind_front(FnPtr, model), std::move(argsConverted)); + callInvoker->invokeAsync([promise](jsi::Runtime &runtime) { + promise->resolve(jsi::Value::undefined()); + }); + } else { + // For non-void functions, capture the result and convert + // it + auto result = std::apply(std::bind_front(FnPtr, model), std::move(argsConverted)); + // The result is copied. It should either be quickly + // copiable, or passed with a shared_ptr. + callInvoker->invokeAsync([promise, result](jsi::Runtime &runtime) { + promise->resolve(jsi_conversion::getJsiValue(std::move(result), runtime)); + }); + } + } catch (const RnExecutorchError &e) { + auto code = e.getNumericCode(); + auto msg = std::string(e.what()); + callInvoker->invokeAsync([code, msg, promise](jsi::Runtime &runtime) { + promise->reject(makeRnExecutorchErrorValue(runtime, code, msg)); }); + return; + } catch (const std::exception &e) { + callInvoker->invokeAsync( + [e = std::move(e), promise]() { promise->reject(std::string(e.what())); }); + return; + } catch (...) { + callInvoker->invokeAsync( + [promise]() { promise->reject(std::string("Unknown error")); }); + return; + } + }); } catch (...) { - promise->reject(std::string( - "Couldn't parse JS arguments in a native function")); + promise->reject(std::string("Couldn't parse JS arguments in a native function")); } }); @@ -448,8 +381,8 @@ template class ModelHostObject : public JsiHostObject { model->unload(); thisValue.asObject(runtime).setExternalMemoryPressure(runtime, 0); } catch (const RnExecutorchError &e) { - throw jsi::JSError(runtime, makeRnExecutorchErrorValue( - runtime, e.getNumericCode(), e.what())); + throw jsi::JSError(runtime, + makeRnExecutorchErrorValue(runtime, e.getNumericCode(), e.what())); } catch (const std::exception &e) { throw jsi::JSError(runtime, e.what()); } catch (...) { diff --git a/packages/react-native-executorch/common/rnexecutorch/jsi/JsiHostObject.cpp b/packages/react-native-executorch/common/rnexecutorch/jsi/JsiHostObject.cpp index d07434325c..73563d73b5 100644 --- a/packages/react-native-executorch/common/rnexecutorch/jsi/JsiHostObject.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/jsi/JsiHostObject.cpp @@ -11,15 +11,15 @@ std::vector objects; #endif JsiHostObject::JsiHostObject() { - getters_ = std::make_unique>(); - functions_ = std::make_unique< - std::unordered_map>(); - setters_ = std::make_unique>(); + getters_ = std::make_unique< + std::unordered_map>(); + functions_ = + std::make_unique>(); + setters_ = + std::make_unique>(); #if JSI_DEBUG_ALLOCATIONS objects.push_back(this); @@ -41,8 +41,7 @@ JsiHostObject::~JsiHostObject() { std::vector JsiHostObject::getPropertyNames(jsi::Runtime &rt) { std::vector propertyNames; - propertyNames.reserve(getters_->size() + functions_->size() + - setters_->size()); + propertyNames.reserve(getters_->size() + functions_->size() + setters_->size()); for (const auto &it : *getters_) { propertyNames.push_back(jsi::PropNameID::forUtf8(rt, it.first)); @@ -59,8 +58,7 @@ std::vector JsiHostObject::getPropertyNames(jsi::Runtime &rt) { return propertyNames; } -jsi::Value JsiHostObject::get(jsi::Runtime &runtime, - const jsi::PropNameID &name) { +jsi::Value JsiHostObject::get(jsi::Runtime &runtime, const jsi::PropNameID &name) { auto nameAsString = name.utf8(runtime); auto &hostFunctionCache = hostFunctionCache_.get(runtime); @@ -82,13 +80,11 @@ jsi::Value JsiHostObject::get(jsi::Runtime &runtime, } auto dispatcher = - std::bind(function->second, reinterpret_cast(this), - std::placeholders::_1, std::placeholders::_2, - std::placeholders::_3, std::placeholders::_4); + std::bind(function->second, reinterpret_cast(this), std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, std::placeholders::_4); return hostFunctionCache - .emplace(nameAsString, jsi::Function::createFromHostFunction( - runtime, name, 0, dispatcher)) + .emplace(nameAsString, jsi::Function::createFromHostFunction(runtime, name, 0, dispatcher)) .first->second.asFunction(runtime); } @@ -99,10 +95,9 @@ void JsiHostObject::set(jsi::Runtime &runtime, const jsi::PropNameID &name, auto setter = setters_->find(nameAsString); if (setter != setters_->end()) { - auto dispatcher = std::bind(setter->second, this, std::placeholders::_1, - std::placeholders::_2); + auto dispatcher = std::bind(setter->second, this, std::placeholders::_1, std::placeholders::_2); return dispatcher(runtime, value); } } -} // namespace rnexecutorch \ No newline at end of file +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/jsi/JsiHostObject.h b/packages/react-native-executorch/common/rnexecutorch/jsi/JsiHostObject.h index 8b78f3f885..e10fd1dc0b 100644 --- a/packages/react-native-executorch/common/rnexecutorch/jsi/JsiHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/jsi/JsiHostObject.h @@ -11,31 +11,27 @@ #include #include -#define JSI_HOST_FUNCTION(NAME) \ - jsi::Value NAME(jsi::Runtime &runtime, const jsi::Value &thisValue, \ - const jsi::Value *args, size_t count) +#define JSI_HOST_FUNCTION(NAME) \ + jsi::Value NAME(jsi::Runtime &runtime, const jsi::Value &thisValue, const jsi::Value *args, \ + size_t count) -#define JSI_EXPORT_FUNCTION(CLASS, FUNCTION, NAME) \ - std::make_pair( \ - NAME, \ - static_cast( \ - &CLASS::FUNCTION)) +#define JSI_EXPORT_FUNCTION(CLASS, FUNCTION, NAME) \ + std::make_pair(NAME, static_cast( \ + &CLASS::FUNCTION)) #define JSI_PROPERTY_GETTER(name) jsi::Value name(jsi::Runtime &runtime) -#define JSI_EXPORT_PROPERTY_GETTER(CLASS, FUNCTION) \ - std::make_pair(std::string(#FUNCTION), \ - static_cast( \ - &CLASS::FUNCTION)) +#define JSI_EXPORT_PROPERTY_GETTER(CLASS, FUNCTION) \ + std::make_pair(std::string(#FUNCTION), \ + static_cast(&CLASS::FUNCTION)) -#define JSI_PROPERTY_SETTER(name) \ - void name(jsi::Runtime &runtime, const jsi::Value &value) +#define JSI_PROPERTY_SETTER(name) void name(jsi::Runtime &runtime, const jsi::Value &value) -#define JSI_EXPORT_PROPERTY_SETTER(CLASS, FUNCTION) \ - std::make_pair(std::string(#FUNCTION), \ - static_cast(&CLASS::FUNCTION)) +#define JSI_EXPORT_PROPERTY_SETTER(CLASS, FUNCTION) \ + std::make_pair( \ + std::string(#FUNCTION), \ + static_cast(&CLASS::FUNCTION)) namespace rnexecutorch { @@ -50,38 +46,29 @@ class JsiHostObject : public jsi::HostObject { jsi::Value get(jsi::Runtime &runtime, const jsi::PropNameID &name) override; - void set(jsi::Runtime &runtime, const jsi::PropNameID &name, - const jsi::Value &value) override; + void set(jsi::Runtime &runtime, const jsi::PropNameID &name, const jsi::Value &value) override; - template void addGetters(Args... args) { - (getters_->insert(args), ...); - } + template void addGetters(Args... args) { (getters_->insert(args), ...); } - template void addSetters(Args... args) { - (setters_->insert(args), ...); - } + template void addSetters(Args... args) { (setters_->insert(args), ...); } - template void addFunctions(Args... args) { - (functions_->insert(args), ...); - } + template void addFunctions(Args... args) { (functions_->insert(args), ...); } protected: - std::unique_ptr> + std::unique_ptr> getters_; - std::unique_ptr< - std::unordered_map> + std::unique_ptr> functions_; - std::unique_ptr> + std::unique_ptr< + std::unordered_map> setters_; private: RuntimeAwareCache> hostFunctionCache_; }; -} // namespace rnexecutorch \ No newline at end of file +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/jsi/OwningArrayBuffer.h b/packages/react-native-executorch/common/rnexecutorch/jsi/OwningArrayBuffer.h index a76fb02707..132ad8df28 100644 --- a/packages/react-native-executorch/common/rnexecutorch/jsi/OwningArrayBuffer.h +++ b/packages/react-native-executorch/common/rnexecutorch/jsi/OwningArrayBuffer.h @@ -23,9 +23,7 @@ class OwningArrayBuffer : public jsi::MutableBuffer { /** * @param size Size of the buffer in bytes. */ - OwningArrayBuffer(size_t size) : size_(size) { - data_ = new uint8_t[size_]; - } + OwningArrayBuffer(size_t size) : size_(size) { data_ = new uint8_t[size_]; } /** * @param data Pointer to the source data. * @param size Size of the data in bytes. @@ -54,4 +52,4 @@ class OwningArrayBuffer : public jsi::MutableBuffer { const size_t size_; }; -} // namespace rnexecutorch \ No newline at end of file +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/jsi/Promise.cpp b/packages/react-native-executorch/common/rnexecutorch/jsi/Promise.cpp index 510890e344..aac16d20bc 100644 --- a/packages/react-native-executorch/common/rnexecutorch/jsi/Promise.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/jsi/Promise.cpp @@ -2,11 +2,10 @@ namespace rnexecutorch { -Promise::Promise(jsi::Runtime &runtime, - std::shared_ptr callInvoker, +Promise::Promise(jsi::Runtime &runtime, std::shared_ptr callInvoker, jsi::Value resolver, jsi::Value rejecter) - : runtime(runtime), callInvoker(callInvoker), - _resolver(std::move(resolver)), _rejecter(std::move(rejecter)) {} + : runtime(runtime), callInvoker(callInvoker), _resolver(std::move(resolver)), + _rejecter(std::move(rejecter)) {} void Promise::resolve(jsi::Value &&result) { _resolver.asObject(runtime).asFunction(runtime).call(runtime, result); diff --git a/packages/react-native-executorch/common/rnexecutorch/jsi/Promise.h b/packages/react-native-executorch/common/rnexecutorch/jsi/Promise.h index b278daa918..74ab3333cb 100644 --- a/packages/react-native-executorch/common/rnexecutorch/jsi/Promise.h +++ b/packages/react-native-executorch/common/rnexecutorch/jsi/Promise.h @@ -13,15 +13,13 @@ using namespace facebook; class Promise; template -concept PromiseRunFn = - std::invocable> && - std::same_as>, void>; +concept PromiseRunFn = std::invocable> && + std::same_as>, void>; class Promise { public: - Promise(jsi::Runtime &runtime, - std::shared_ptr callInvoker, jsi::Value resolver, - jsi::Value rejecter); + Promise(jsi::Runtime &runtime, std::shared_ptr callInvoker, + jsi::Value resolver, jsi::Value rejecter); Promise(const Promise &) = delete; Promise &operator=(const Promise &) = delete; @@ -36,22 +34,19 @@ class Promise { and be able to bind a lambda. */ template - static jsi::Value - createPromise(jsi::Runtime &runtime, - std::shared_ptr callInvoker, Fn &&run) { + static jsi::Value createPromise(jsi::Runtime &runtime, + std::shared_ptr callInvoker, Fn &&run) { // Get Promise ctor from global - auto promiseCtor = - runtime.global().getPropertyAsFunction(runtime, "Promise"); + auto promiseCtor = runtime.global().getPropertyAsFunction(runtime, "Promise"); auto promiseCallback = jsi::Function::createFromHostFunction( runtime, jsi::PropNameID::forUtf8(runtime, "PromiseCallback"), 2, - [run = std::move(run), - callInvoker](jsi::Runtime &runtime, const jsi::Value &thisValue, - const jsi::Value *arguments, size_t count) -> jsi::Value { + [run = std::move(run), callInvoker](jsi::Runtime &runtime, const jsi::Value &thisValue, + const jsi::Value *arguments, + size_t count) -> jsi::Value { // Call function auto promise = std::make_shared( - runtime, callInvoker, arguments[0].asObject(runtime), - arguments[1].asObject(runtime)); + runtime, callInvoker, arguments[0].asObject(runtime), arguments[1].asObject(runtime)); run(promise); return jsi::Value::undefined(); diff --git a/packages/react-native-executorch/common/rnexecutorch/jsi/RuntimeAwareCache.h b/packages/react-native-executorch/common/rnexecutorch/jsi/RuntimeAwareCache.h index dc7c8d0ff5..c2c3ee2cf9 100644 --- a/packages/react-native-executorch/common/rnexecutorch/jsi/RuntimeAwareCache.h +++ b/packages/react-native-executorch/common/rnexecutorch/jsi/RuntimeAwareCache.h @@ -23,8 +23,7 @@ using namespace facebook; * managed by the runtime, accessing that portion of the memory after runtime is * deleted is the root cause of that crash). */ -template -class RuntimeAwareCache : public RuntimeLifecycleListener { +template class RuntimeAwareCache : public RuntimeLifecycleListener { public: void onRuntimeDestroyed(jsi::Runtime *rt) override { // A runtime has been destroyed, so destroy the related cache. @@ -55,4 +54,4 @@ class RuntimeAwareCache : public RuntimeLifecycleListener { std::unordered_map runtimeCaches_; }; -} // namespace rnexecutorch \ No newline at end of file +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/jsi/RuntimeLifecycleMonitor.cpp b/packages/react-native-executorch/common/rnexecutorch/jsi/RuntimeLifecycleMonitor.cpp index a0d5465caf..bbe9b2ebf4 100644 --- a/packages/react-native-executorch/common/rnexecutorch/jsi/RuntimeLifecycleMonitor.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/jsi/RuntimeLifecycleMonitor.cpp @@ -2,9 +2,7 @@ namespace rnexecutorch { -static std::unordered_map> - listeners; +static std::unordered_map> listeners; struct RuntimeLifecycleMonitorObject : public jsi::HostObject { jsi::Runtime *rt_; @@ -20,18 +18,16 @@ struct RuntimeLifecycleMonitorObject : public jsi::HostObject { } }; -void RuntimeLifecycleMonitor::addListener(jsi::Runtime &rt, - RuntimeLifecycleListener *listener) { +void RuntimeLifecycleMonitor::addListener(jsi::Runtime &rt, RuntimeLifecycleListener *listener) { auto listenersSet = listeners.find(&rt); if (listenersSet == listeners.end()) { // We install a global host object in the provided runtime, this way we can // use that host object destructor to get notified when the runtime is being // terminated. We use a unique name for the object as it gets saved with the // runtime's global object. - rt.global().setProperty( - rt, "__rnaudioapi_runtime_lifecycle_monitor", - jsi::Object::createFromHostObject( - rt, std::make_shared(&rt))); + rt.global().setProperty(rt, "__rnaudioapi_runtime_lifecycle_monitor", + jsi::Object::createFromHostObject( + rt, std::make_shared(&rt))); std::unordered_set newSet; newSet.insert(listener); listeners.emplace(&rt, std::move(newSet)); @@ -40,8 +36,7 @@ void RuntimeLifecycleMonitor::addListener(jsi::Runtime &rt, } } -void RuntimeLifecycleMonitor::removeListener( - jsi::Runtime &rt, RuntimeLifecycleListener *listener) { +void RuntimeLifecycleMonitor::removeListener(jsi::Runtime &rt, RuntimeLifecycleListener *listener) { auto listenersSet = listeners.find(&rt); if (listenersSet == listeners.end()) { // nothing to do here @@ -50,4 +45,4 @@ void RuntimeLifecycleMonitor::removeListener( } } -} // namespace rnexecutorch \ No newline at end of file +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/jsi/RuntimeLifecycleMonitor.h b/packages/react-native-executorch/common/rnexecutorch/jsi/RuntimeLifecycleMonitor.h index 7d0e88bbb6..1d44ad3a1c 100644 --- a/packages/react-native-executorch/common/rnexecutorch/jsi/RuntimeLifecycleMonitor.h +++ b/packages/react-native-executorch/common/rnexecutorch/jsi/RuntimeLifecycleMonitor.h @@ -28,8 +28,7 @@ struct RuntimeLifecycleListener { */ struct RuntimeLifecycleMonitor { static void addListener(jsi::Runtime &rt, RuntimeLifecycleListener *listener); - static void removeListener(jsi::Runtime &rt, - RuntimeLifecycleListener *listener); + static void removeListener(jsi::Runtime &rt, RuntimeLifecycleListener *listener); }; -} // namespace rnexecutorch \ No newline at end of file +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/ConstructorHelpers.h b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/ConstructorHelpers.h index c3cdda6768..580e4b7942 100644 --- a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/ConstructorHelpers.h +++ b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/ConstructorHelpers.h @@ -29,36 +29,29 @@ using namespace facebook; template struct ConstructorTraits; template -concept HasConstructorTraits = - requires { typename ConstructorTraits::arg_types; }; +concept HasConstructorTraits = requires { typename ConstructorTraits::arg_types; }; template struct is_constructible_from_tuple; template -struct is_constructible_from_tuple> - : std::is_constructible {}; +struct is_constructible_from_tuple> : std::is_constructible {}; template concept ConstructibleFromTuple = is_constructible_from_tuple::value; -template -struct last_element_is_call_invoker : std::false_type {}; +template struct last_element_is_call_invoker : std::false_type {}; -template -struct last_element_is_call_invoker> { +template struct last_element_is_call_invoker> { private: template static constexpr bool check() { return std::is_same_v>; } - template - static constexpr bool check_last() { + template static constexpr bool check_last() { return check_last(); } - template static constexpr bool check_last() { - return check(); - } + template static constexpr bool check_last() { return check(); } public: static constexpr bool value = sizeof...(Args) > 0 && check_last(); @@ -69,19 +62,18 @@ struct last_element_is_call_invoker> { // it wouldn't be defined, but we keep it for readability template concept ValidConstructorTraits = - HasConstructorTraits && - ConstructibleFromTuple::arg_types>; + HasConstructorTraits && ConstructibleFromTuple::arg_types>; template concept CallInvokerLastInConstructor = HasConstructorTraits && - last_element_is_call_invoker< - typename ConstructorTraits::arg_types>::value; + last_element_is_call_invoker::arg_types>::value; template -std::tuple fillConstructorTupleFromArgs( - std::index_sequence, const jsi::Value *args, jsi::Runtime &runtime, - std::shared_ptr jsCallInvoker) { +std::tuple +fillConstructorTupleFromArgs(std::index_sequence, const jsi::Value *args, + jsi::Runtime &runtime, + std::shared_ptr jsCallInvoker) { constexpr std::size_t lastIndex = sizeof...(Types) - 1; return std::make_tuple([&]() { if constexpr (I == lastIndex) { @@ -104,14 +96,12 @@ std::tuple fillConstructorTupleFromArgs( /// @return A tuple which can then be used to instantiate the class T. template requires ValidConstructorTraits && CallInvokerLastInConstructor -auto createConstructorArgsWithCallInvoker( - const jsi::Value *args, jsi::Runtime &runtime, - std::shared_ptr jsCallInvoker) { +auto createConstructorArgsWithCallInvoker(const jsi::Value *args, jsi::Runtime &runtime, + std::shared_ptr jsCallInvoker) { return std::apply( [&](auto... typeWrappers) { return fillConstructorTupleFromArgs( - std::index_sequence_for{}, args, runtime, - jsCallInvoker); + std::index_sequence_for{}, args, runtime, jsCallInvoker); }, typename ConstructorTraits::arg_types{}); } @@ -125,9 +115,9 @@ auto createConstructorArgsWithCallInvoker( * @note The Class must be fully declared or forward-declared before this macro * is invoked */ -#define REGISTER_CONSTRUCTOR(Class, ...) \ - template <> struct meta::ConstructorTraits { \ - using arg_types = std::tuple<__VA_ARGS__>; \ +#define REGISTER_CONSTRUCTOR(Class, ...) \ + template <> struct meta::ConstructorTraits { \ + using arg_types = std::tuple<__VA_ARGS__>; \ } -} // namespace rnexecutorch \ No newline at end of file +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/ContainerHelpers.h b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/ContainerHelpers.h index 941ddedbb2..5ad4c549d8 100644 --- a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/ContainerHelpers.h +++ b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/ContainerHelpers.h @@ -19,14 +19,11 @@ * @note The macro prints the variable name, file, and line for easier * debugging. */ -#define CHECK_SIZE(container, expected) \ - if ((container).size() != (expected)) { \ - rnexecutorch::log(rnexecutorch::LOG_LEVEL::Error, \ - "Unexpected size for " #container " at ", \ - std::filesystem::path(__FILE__).filename().string(), \ - ":", __LINE__, ": expected ", (expected), " but got ", \ - (container).size()); \ - throw rnexecutorch::RnExecutorchError( \ - rnexecutorch::RnExecutorchErrorCode::WrongDimensions, \ - "Invalid shape of " #container); \ - } \ No newline at end of file +#define CHECK_SIZE(container, expected) \ + if ((container).size() != (expected)) { \ + rnexecutorch::log(rnexecutorch::LOG_LEVEL::Error, "Unexpected size for " #container " at ", \ + std::filesystem::path(__FILE__).filename().string(), ":", __LINE__, \ + ": expected ", (expected), " but got ", (container).size()); \ + throw rnexecutorch::RnExecutorchError(rnexecutorch::RnExecutorchErrorCode::WrongDimensions, \ + "Invalid shape of " #container); \ + } diff --git a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h index fde81e046d..eddea93071 100644 --- a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h +++ b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h @@ -21,11 +21,9 @@ constexpr std::size_t getArgumentCount(R (Model::*f)(Types...) const) { } template -std::tuple fillTupleFromArgs(std::index_sequence, - const jsi::Value *args, +std::tuple fillTupleFromArgs(std::index_sequence, const jsi::Value *args, jsi::Runtime &runtime) { - return std::tuple{ - jsi_conversion::getValue(args[I], runtime)...}; + return std::tuple{jsi_conversion::getValue(args[I], runtime)...}; } /** @@ -35,30 +33,24 @@ std::tuple fillTupleFromArgs(std::index_sequence, */ template -std::tuple createArgsTupleFromJsi(R (Model::*f)(Types...), - const jsi::Value *args, +std::tuple createArgsTupleFromJsi(R (Model::*f)(Types...), const jsi::Value *args, jsi::Runtime &runtime) { - return fillTupleFromArgs(std::index_sequence_for{}, args, - runtime); + return fillTupleFromArgs(std::index_sequence_for{}, args, runtime); } template -std::tuple createArgsTupleFromJsi(R (Model::*f)(Types...) const, - const jsi::Value *args, +std::tuple createArgsTupleFromJsi(R (Model::*f)(Types...) const, const jsi::Value *args, jsi::Runtime &runtime) { - return fillTupleFromArgs(std::index_sequence_for{}, args, - runtime); + return fillTupleFromArgs(std::index_sequence_for{}, args, runtime); } // Free function overload used by visionHostFunction: accepts a dummy free // function pointer whose parameter types (Rest...) are extracted by // TailSignature and converted from JSI args. template -std::tuple createArgsTupleFromJsi(void (*f)(Types...), - const jsi::Value *args, +std::tuple createArgsTupleFromJsi(void (*f)(Types...), const jsi::Value *args, jsi::Runtime &runtime) { - return fillTupleFromArgs(std::index_sequence_for{}, args, - runtime); + return fillTupleFromArgs(std::index_sequence_for{}, args, runtime); } // Extracts arity, return type, and argument types from a member function @@ -66,15 +58,13 @@ std::tuple createArgsTupleFromJsi(void (*f)(Types...), // JS argument count and invoke the correct return path. template struct FunctionTraits; -template -struct FunctionTraits { +template struct FunctionTraits { static constexpr std::size_t arity = sizeof...(Args); using return_type = R; using args_tuple = std::tuple; }; -template -struct FunctionTraits { +template struct FunctionTraits { static constexpr std::size_t arity = sizeof...(Args); using return_type = R; using args_tuple = std::tuple; @@ -86,14 +76,12 @@ struct FunctionTraits { // createArgsTupleFromJsi, while frameData at args[0] is passed manually. template struct TailSignature; -template +template struct TailSignature { static void dummy(Rest...) {} }; -template +template struct TailSignature { static void dummy(Rest...) {} }; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp index 735b257f51..49547d2d34 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp @@ -12,10 +12,8 @@ using ::executorch::extension::module::Module; using ::executorch::runtime::Error; BaseModel::BaseModel(const std::string &modelSource, - std::shared_ptr callInvoker, - Module::LoadMode loadMode) - : callInvoker(callInvoker), - module_(std::make_unique(modelSource, loadMode)) { + std::shared_ptr callInvoker, Module::LoadMode loadMode) + : callInvoker(callInvoker), module_(std::make_unique(modelSource, loadMode)) { Error loadError = module_->load(); if (loadError != Error::Ok) { throw RnExecutorchError(loadError, "Failed to load model"); @@ -28,8 +26,7 @@ BaseModel::BaseModel(const std::string &modelSource, memorySizeLowerBound = std::filesystem::file_size(modelPath); } -std::vector BaseModel::getInputShape(std::string method_name, - int32_t index) const { +std::vector BaseModel::getInputShape(std::string method_name, int32_t index) const { if (!module_) { THROW_NOT_LOADED_ERROR(); } @@ -37,16 +34,14 @@ std::vector BaseModel::getInputShape(std::string method_name, auto method_meta = module_->method_meta(method_name); if (!method_meta.ok()) { throw RnExecutorchError(method_meta.error(), - "Failed to get metadata for method '" + - method_name + "'"); + "Failed to get metadata for method '" + method_name + "'"); } auto input_meta = method_meta->input_tensor_meta(index); if (!input_meta.ok()) { - throw RnExecutorchError( - input_meta.error(), - "Failed to get metadata for input tensor at index " + - std::to_string(index) + " in method '" + method_name + "'"); + throw RnExecutorchError(input_meta.error(), + "Failed to get metadata for input tensor at index " + + std::to_string(index) + " in method '" + method_name + "'"); } auto sizes = input_meta->sizes(); @@ -54,8 +49,7 @@ std::vector BaseModel::getInputShape(std::string method_name, return input_shape; } -std::vector> -BaseModel::getAllInputShapes(std::string methodName) const { +std::vector> BaseModel::getAllInputShapes(std::string methodName) const { if (!module_) { THROW_NOT_LOADED_ERROR(); } @@ -63,8 +57,7 @@ BaseModel::getAllInputShapes(std::string methodName) const { auto method_meta = module_->method_meta(methodName); if (!method_meta.ok()) { throw RnExecutorchError(method_meta.error(), - "Failed to get metadata for method '" + methodName + - "'"); + "Failed to get metadata for method '" + methodName + "'"); } std::vector> output; std::size_t numInputs = method_meta->num_inputs(); @@ -72,10 +65,9 @@ BaseModel::getAllInputShapes(std::string methodName) const { for (std::size_t input = 0; input < numInputs; ++input) { auto input_meta = method_meta->input_tensor_meta(input); if (!input_meta.ok()) { - throw RnExecutorchError( - input_meta.error(), - "Failed to get metadata for input tensor at index " + - std::to_string(input) + " in method '" + methodName + "'"); + throw RnExecutorchError(input_meta.error(), + "Failed to get metadata for input tensor at index " + + std::to_string(input) + " in method '" + methodName + "'"); } auto shape = input_meta->sizes(); output.emplace_back(std::vector(shape.begin(), shape.end())); @@ -86,8 +78,7 @@ BaseModel::getAllInputShapes(std::string methodName) const { /// @brief This method is a forward wrapper that is created solely to be exposed /// to JS. It is not meant to be used within C++. If you want to call forward /// from C++ on a BaseModel, please use BaseModel::forward. -std::vector -BaseModel::forwardJS(std::vector tensorViewVec) const { +std::vector BaseModel::forwardJS(std::vector tensorViewVec) const { if (!module_) { THROW_NOT_LOADED_ERROR(); } @@ -103,8 +94,7 @@ BaseModel::forwardJS(std::vector tensorViewVec) const { for (size_t i = 0; i < tensorViewVec.size(); i++) { const auto &currTensorView = tensorViewVec[i]; auto tensorPtr = - make_tensor_ptr(currTensorView.sizes, currTensorView.dataPtr, - currTensorView.scalarType); + make_tensor_ptr(currTensorView.sizes, currTensorView.dataPtr, currTensorView.scalarType); tensorPtrs.emplace_back(tensorPtr); evalues.emplace_back(*tensorPtr); // Dereference TensorPtr to get Tensor, // which implicitly converts to EValue @@ -123,8 +113,7 @@ BaseModel::forwardJS(std::vector tensorViewVec) const { auto &outputTensor = outputs[i].toTensor(); std::vector sizes = getTensorShape(outputTensor); size_t bufferSize = outputTensor.numel() * outputTensor.element_size(); - auto buffer = std::make_shared( - outputTensor.const_data_ptr(), bufferSize); + auto buffer = std::make_shared(outputTensor.const_data_ptr(), bufferSize); auto jsTensor = JSTensorViewOut(sizes, outputTensor.scalar_type(), buffer); output.emplace_back(jsTensor); } @@ -139,39 +128,33 @@ BaseModel::getMethodMeta(const std::string &methodName) const { return module_->method_meta(methodName); } -Result> -BaseModel::forward(const EValue &input_evalue) const { +Result> BaseModel::forward(const EValue &input_evalue) const { if (!module_) { THROW_NOT_LOADED_ERROR(); } return module_->forward(input_evalue); } -Result> -BaseModel::forward(const std::vector &input_evalues) const { +Result> BaseModel::forward(const std::vector &input_evalues) const { if (!module_) { THROW_NOT_LOADED_ERROR(); } return module_->forward(input_evalues); } -Result> -BaseModel::execute(const std::string &methodName, - const std::vector &input_value) const { +Result> BaseModel::execute(const std::string &methodName, + const std::vector &input_value) const { if (!module_) { THROW_NOT_LOADED_ERROR(); } return module_->execute(methodName, input_value); } -std::size_t BaseModel::getMemoryLowerBound() const noexcept { - return memorySizeLowerBound; -} +std::size_t BaseModel::getMemoryLowerBound() const noexcept { return memorySizeLowerBound; } void BaseModel::unload() noexcept { module_.reset(nullptr); } -std::vector -BaseModel::getTensorShape(const executorch::aten::Tensor &tensor) const { +std::vector BaseModel::getTensorShape(const executorch::aten::Tensor &tensor) const { auto sizes = tensor.sizes(); return std::vector(sizes.begin(), sizes.end()); } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h index 6d44976b90..fa2b817d54 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h @@ -22,26 +22,20 @@ class BaseModel { virtual ~BaseModel() = default; BaseModel(BaseModel &&) = default; BaseModel &operator=(BaseModel &&) = default; - BaseModel( - const std::string &modelSource, - std::shared_ptr callInvoker, - Module::LoadMode loadMode = Module::LoadMode::MmapUseMlockIgnoreErrors); + BaseModel(const std::string &modelSource, std::shared_ptr callInvoker, + Module::LoadMode loadMode = Module::LoadMode::MmapUseMlockIgnoreErrors); std::size_t getMemoryLowerBound() const noexcept; void unload() noexcept; [[nodiscard("Registered non-void function")]] std::vector getInputShape(std::string method_name, int32_t index) const; - std::vector> - getAllInputShapes(std::string methodName = "forward") const; + std::vector> getAllInputShapes(std::string methodName = "forward") const; [[nodiscard("Registered non-void function")]] std::vector forwardJS(std::vector tensorViewVec) const; Result> forward(const EValue &input_value) const; - Result> - forward(const std::vector &input_value) const; - Result> - execute(const std::string &methodName, - const std::vector &input_value) const; - Result - getMethodMeta(const std::string &methodName) const; + Result> forward(const std::vector &input_value) const; + Result> execute(const std::string &methodName, + const std::vector &input_value) const; + Result getMethodMeta(const std::string &methodName) const; protected: // If possible, models should not use the JS runtime to keep JSI internals @@ -54,11 +48,9 @@ class BaseModel { std::size_t memorySizeLowerBound{0}; private: - std::vector - getTensorShape(const executorch::aten::Tensor &tensor) const; + std::vector getTensorShape(const executorch::aten::Tensor &tensor) const; }; } // namespace models -REGISTER_CONSTRUCTOR(models::BaseModel, std::string, - std::shared_ptr); +REGISTER_CONSTRUCTOR(models::BaseModel, std::string, std::shared_ptr); } // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp index cc9c862b32..b87e83b5c8 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp @@ -25,8 +25,7 @@ cv::Size VisionModel::modelInputSize() const { modelInputShape_[modelInputShape_.size() - 2]); } -cv::Mat VisionModel::extractFromFrame(jsi::Runtime &runtime, - const jsi::Value &frameData) const { +cv::Mat VisionModel::extractFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData) const { cv::Mat frame = ::rnexecutorch::utils::frameToMat(runtime, frameData); cv::Mat rgb; #ifdef __APPLE__ diff --git a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h index cf003948af..490a606e32 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h @@ -42,8 +42,7 @@ namespace models { */ class VisionModel : public BaseModel { public: - VisionModel(const std::string &modelSource, - std::shared_ptr callInvoker); + VisionModel(const std::string &modelSource, std::shared_ptr callInvoker); virtual ~VisionModel() = default; @@ -111,8 +110,7 @@ class VisionModel : public BaseModel { * * @note Does NOT acquire the inference mutex — caller is responsible */ - cv::Mat extractFromFrame(jsi::Runtime &runtime, - const jsi::Value &frameData) const; + cv::Mat extractFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData) const; /** * @brief Extract cv::Mat from raw pixel data (TensorPtr) sent from @@ -150,7 +148,6 @@ class VisionModel : public BaseModel { // Register VisionModel constructor traits // Even though VisionModel is abstract, the metaprogramming system needs to know // its constructor signature for derived classes -REGISTER_CONSTRUCTOR(models::VisionModel, std::string, - std::shared_ptr); +REGISTER_CONSTRUCTOR(models::VisionModel, std::string, std::shared_ptr); } // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp index d31f78607f..d022d4629f 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp @@ -9,24 +9,19 @@ namespace rnexecutorch::models::classification { -Classification::Classification(const std::string &modelSource, - std::vector normMean, - std::vector normStd, - std::vector labelNames, +Classification::Classification(const std::string &modelSource, std::vector normMean, + std::vector normStd, std::vector labelNames, std::shared_ptr callInvoker) - : VisionModel(modelSource, callInvoker), - labelNames_(std::move(labelNames)) { + : VisionModel(modelSource, callInvoker), labelNames_(std::move(labelNames)) { if (normMean.size() == 3) { normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]); } else if (!normMean.empty()) { - log(LOG_LEVEL::Warn, - "normMean must have 3 elements — ignoring provided value."); + log(LOG_LEVEL::Warn, "normMean must have 3 elements — ignoring provided value."); } if (normStd.size() == 3) { normStd_ = cv::Scalar(normStd[0], normStd[1], normStd[2]); } else if (!normStd.empty()) { - log(LOG_LEVEL::Warn, - "normStd must have 3 elements — ignoring provided value."); + log(LOG_LEVEL::Warn, "normStd must have 3 elements — ignoring provided value."); } auto inputShapes = getAllInputShapes(); @@ -41,23 +36,19 @@ Classification::Classification(const std::string &modelSource, "Unexpected model input size, expected at least 2 dimensions " "but got: %zu.", modelInputShape_.size()); - throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, - errorMessage); + throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, errorMessage); } } -std::unordered_map -Classification::runInference(cv::Mat image) { +std::unordered_map Classification::runInference(cv::Mat image) { std::scoped_lock lock(inference_mutex_); cv::Mat preprocessed = preprocess(image); - auto inputTensor = - (normMean_ && normStd_) - ? image_processing::getTensorFromMatrix( - modelInputShape_, preprocessed, *normMean_, *normStd_) - : image_processing::getTensorFromMatrix(modelInputShape_, - preprocessed); + auto inputTensor = (normMean_ && normStd_) + ? image_processing::getTensorFromMatrix(modelInputShape_, preprocessed, + *normMean_, *normStd_) + : image_processing::getTensorFromMatrix(modelInputShape_, preprocessed); auto forwardResult = BaseModel::forward(inputTensor); CHECK_OK_OR_THROW_FORWARD_ERROR(forwardResult); @@ -75,8 +66,7 @@ Classification::generateFromString(std::string imageSource) { } std::unordered_map -Classification::generateFromFrame(jsi::Runtime &runtime, - const jsi::Value &frameData) { +Classification::generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData) { cv::Mat frame = extractFromFrame(runtime, frameData); return runInference(frame); } @@ -88,21 +78,18 @@ Classification::generateFromPixels(JSTensorViewIn pixelData) { return runInference(image); } -std::unordered_map -Classification::postprocess(const Tensor &tensor) { - std::span resultData( - static_cast(tensor.const_data_ptr()), tensor.numel()); +std::unordered_map Classification::postprocess(const Tensor &tensor) { + std::span resultData(static_cast(tensor.const_data_ptr()), + tensor.numel()); std::vector resultVec(resultData.begin(), resultData.end()); if (resultVec.size() != labelNames_.size()) { char errorMessage[100]; - std::snprintf( - errorMessage, sizeof(errorMessage), - "Unexpected classification output size, was expecting: %zu classes " - "but got: %zu classes", - labelNames_.size(), resultVec.size()); - throw RnExecutorchError(RnExecutorchErrorCode::InvalidModelOutput, - errorMessage); + std::snprintf(errorMessage, sizeof(errorMessage), + "Unexpected classification output size, was expecting: %zu classes " + "but got: %zu classes", + labelNames_.size(), resultVec.size()); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidModelOutput, errorMessage); } numerical::softmax(resultVec); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h index 2ea0e17bbb..c7f02799bc 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h @@ -18,20 +18,16 @@ using executorch::extension::TensorPtr; class Classification : public VisionModel { public: Classification(const std::string &modelSource, std::vector normMean, - std::vector normStd, - std::vector labelNames, + std::vector normStd, std::vector labelNames, std::shared_ptr callInvoker); - [[nodiscard("Registered non-void function")]] std::unordered_map< - std::string_view, float> + [[nodiscard("Registered non-void function")]] std::unordered_map generateFromString(std::string imageSource); - [[nodiscard("Registered non-void function")]] std::unordered_map< - std::string_view, float> + [[nodiscard("Registered non-void function")]] std::unordered_map generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData); - [[nodiscard("Registered non-void function")]] std::unordered_map< - std::string_view, float> + [[nodiscard("Registered non-void function")]] std::unordered_map generateFromPixels(JSTensorViewIn pixelData); private: @@ -45,8 +41,7 @@ class Classification : public VisionModel { }; } // namespace models::classification -REGISTER_CONSTRUCTOR(models::classification::Classification, std::string, - std::vector, std::vector, - std::vector, +REGISTER_CONSTRUCTOR(models::classification::Classification, std::string, std::vector, + std::vector, std::vector, std::shared_ptr); } // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp index bf291136c1..fe3618851d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp @@ -11,8 +11,8 @@ BaseEmbeddings::BaseEmbeddings(const std::string &modelSource, std::shared_ptr BaseEmbeddings::postprocess(const Result> &forwardResult) { auto forwardResultTensor = forwardResult->at(0).toTensor(); - auto buffer = std::make_shared( - forwardResultTensor.const_data_ptr(), forwardResultTensor.nbytes()); + auto buffer = std::make_shared(forwardResultTensor.const_data_ptr(), + forwardResultTensor.nbytes()); return buffer; } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.h b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.h index 216d6bf8ce..51ee8ec379 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.h @@ -6,12 +6,10 @@ namespace rnexecutorch::models::embeddings { class BaseEmbeddings : public BaseModel { public: - BaseEmbeddings(const std::string &modelSource, - std::shared_ptr callInvoker); + BaseEmbeddings(const std::string &modelSource, std::shared_ptr callInvoker); protected: - std::shared_ptr - postprocess(const Result> &forwardResult); + std::shared_ptr postprocess(const Result> &forwardResult); }; }; // namespace rnexecutorch::models::embeddings diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp index 03937334cc..0ec9fc1827 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp @@ -6,9 +6,8 @@ namespace rnexecutorch::models::embeddings { -ImageEmbeddings::ImageEmbeddings( - const std::string &modelSource, - std::shared_ptr callInvoker) +ImageEmbeddings::ImageEmbeddings(const std::string &modelSource, + std::shared_ptr callInvoker) : VisionModel(modelSource, callInvoker) { auto inputTensors = getAllInputShapes(); if (inputTensors.size() == 0) { @@ -22,31 +21,27 @@ ImageEmbeddings::ImageEmbeddings( "Unexpected model input size, expected at least 2 dimensions " "but got: %zu.", modelInputShape_.size()); - throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, - errorMessage); + throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, errorMessage); } } -std::shared_ptr -ImageEmbeddings::runInference(cv::Mat image) { +std::shared_ptr ImageEmbeddings::runInference(cv::Mat image) { std::scoped_lock lock(inference_mutex_); cv::Mat preprocessed = preprocess(image); - auto inputTensor = - image_processing::getTensorFromMatrix(modelInputShape_, preprocessed); + auto inputTensor = image_processing::getTensorFromMatrix(modelInputShape_, preprocessed); auto forwardResult = BaseModel::forward(inputTensor); CHECK_OK_OR_THROW_FORWARD_ERROR(forwardResult); auto forwardResultTensor = forwardResult->at(0).toTensor(); - return std::make_shared( - forwardResultTensor.const_data_ptr(), forwardResultTensor.nbytes()); + return std::make_shared(forwardResultTensor.const_data_ptr(), + forwardResultTensor.nbytes()); } -std::shared_ptr -ImageEmbeddings::generateFromString(std::string imageSource) { +std::shared_ptr ImageEmbeddings::generateFromString(std::string imageSource) { cv::Mat imageBGR = image_processing::readImage(imageSource); cv::Mat imageRGB; @@ -55,15 +50,13 @@ ImageEmbeddings::generateFromString(std::string imageSource) { return runInference(imageRGB); } -std::shared_ptr -ImageEmbeddings::generateFromFrame(jsi::Runtime &runtime, - const jsi::Value &frameData) { +std::shared_ptr ImageEmbeddings::generateFromFrame(jsi::Runtime &runtime, + const jsi::Value &frameData) { cv::Mat frame = extractFromFrame(runtime, frameData); return runInference(frame); } -std::shared_ptr -ImageEmbeddings::generateFromPixels(JSTensorViewIn pixelData) { +std::shared_ptr ImageEmbeddings::generateFromPixels(JSTensorViewIn pixelData) { cv::Mat image = extractFromPixels(pixelData); return runInference(image); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h index 3a20301724..0a3d15477f 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h @@ -16,19 +16,15 @@ using executorch::runtime::EValue; class ImageEmbeddings final : public VisionModel { public: - ImageEmbeddings(const std::string &modelSource, - std::shared_ptr callInvoker); + ImageEmbeddings(const std::string &modelSource, std::shared_ptr callInvoker); - [[nodiscard( - "Registered non-void function")]] std::shared_ptr + [[nodiscard("Registered non-void function")]] std::shared_ptr generateFromString(std::string imageSource); - [[nodiscard( - "Registered non-void function")]] std::shared_ptr + [[nodiscard("Registered non-void function")]] std::shared_ptr generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData); - [[nodiscard( - "Registered non-void function")]] std::shared_ptr + [[nodiscard("Registered non-void function")]] std::shared_ptr generateFromPixels(JSTensorViewIn pixelData); private: diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp index ba2c3243b2..3f3ec3aef6 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp @@ -8,12 +8,10 @@ namespace rnexecutorch::models::embeddings { using namespace executorch::extension; -TextEmbeddings::TextEmbeddings(const std::string &modelSource, - const std::string &tokenizerSource, +TextEmbeddings::TextEmbeddings(const std::string &modelSource, const std::string &tokenizerSource, std::shared_ptr callInvoker) : BaseEmbeddings(modelSource, callInvoker), - tokenizer( - std::make_unique(tokenizerSource, callInvoker)) {} + tokenizer(std::make_unique(tokenizerSource, callInvoker)) {} TokenIdsWithAttentionMask TextEmbeddings::preprocess(const std::string &input) { auto inputIds = tokenizer->encode(input); @@ -40,20 +38,16 @@ void TextEmbeddings::unload() noexcept { BaseModel::unload(); } -std::shared_ptr -TextEmbeddings::generate(const std::string input) { +std::shared_ptr TextEmbeddings::generate(const std::string input) { std::scoped_lock lock(inference_mutex_); auto preprocessed = preprocess(input); - std::vector tokenIdsShape = { - 1, static_cast(preprocessed.inputIds.size())}; - std::vector attnMaskShape = { - 1, static_cast(preprocessed.attentionMask.size())}; + std::vector tokenIdsShape = {1, static_cast(preprocessed.inputIds.size())}; + std::vector attnMaskShape = {1, static_cast(preprocessed.attentionMask.size())}; - auto tokenIds = make_tensor_ptr(tokenIdsShape, preprocessed.inputIds.data(), - ScalarType::Long); - auto attnMask = make_tensor_ptr( - attnMaskShape, preprocessed.attentionMask.data(), ScalarType::Long); + auto tokenIds = make_tensor_ptr(tokenIdsShape, preprocessed.inputIds.data(), ScalarType::Long); + auto attnMask = + make_tensor_ptr(attnMaskShape, preprocessed.attentionMask.data(), ScalarType::Long); auto forwardResult = BaseModel::forward({tokenIds, attnMask}); CHECK_OK_OR_THROW_FORWARD_ERROR(forwardResult); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h index 93d0988c04..7ff6599cfe 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h @@ -15,11 +15,9 @@ struct TokenIdsWithAttentionMask { class TextEmbeddings final : public BaseEmbeddings { public: - TextEmbeddings(const std::string &modelSource, - const std::string &tokenizerSource, + TextEmbeddings(const std::string &modelSource, const std::string &tokenizerSource, std::shared_ptr callInvoker); - [[nodiscard( - "Registered non-void function")]] std::shared_ptr + [[nodiscard("Registered non-void function")]] std::shared_ptr generate(const std::string input); void unload() noexcept; @@ -31,6 +29,6 @@ class TextEmbeddings final : public BaseEmbeddings { }; } // namespace models::embeddings -REGISTER_CONSTRUCTOR(models::embeddings::TextEmbeddings, std::string, - std::string, std::shared_ptr); +REGISTER_CONSTRUCTOR(models::embeddings::TextEmbeddings, std::string, std::string, + std::shared_ptr); } // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.cpp index 24ef7a8e22..0d9084f698 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.cpp @@ -12,23 +12,21 @@ namespace rnexecutorch::models::instance_segmentation { -BaseInstanceSegmentation::BaseInstanceSegmentation( - const std::string &modelSource, std::vector normMean, - std::vector normStd, bool applyNMS, - std::shared_ptr callInvoker) +BaseInstanceSegmentation::BaseInstanceSegmentation(const std::string &modelSource, + std::vector normMean, + std::vector normStd, bool applyNMS, + std::shared_ptr callInvoker) : VisionModel(modelSource, callInvoker), applyNMS_(applyNMS) { if (normMean.size() == 3) { normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]); } else if (!normMean.empty()) { - log(LOG_LEVEL::Warn, - "normMean must have 3 elements — ignoring provided value."); + log(LOG_LEVEL::Warn, "normMean must have 3 elements — ignoring provided value."); } if (normStd.size() == 3) { normStd_ = cv::Scalar(normStd[0], normStd[1], normStd[2]); } else if (!normStd.empty()) { - log(LOG_LEVEL::Warn, - "normStd must have 3 elements — ignoring provided value."); + log(LOG_LEVEL::Warn, "normStd must have 3 elements — ignoring provided value."); } } @@ -47,17 +45,15 @@ cv::Size BaseInstanceSegmentation::modelInputSize() const { TensorPtr BaseInstanceSegmentation::buildInputTensor(const cv::Mat &image) { cv::Mat preprocessed = preprocess(image); return (normMean_.has_value() && normStd_.has_value()) - ? image_processing::getTensorFromMatrix( - modelInputShape_, preprocessed, normMean_.value(), - normStd_.value()) - : image_processing::getTensorFromMatrix(modelInputShape_, - preprocessed); + ? image_processing::getTensorFromMatrix(modelInputShape_, preprocessed, + normMean_.value(), normStd_.value()) + : image_processing::getTensorFromMatrix(modelInputShape_, preprocessed); } std::vector BaseInstanceSegmentation::runInference( - const cv::Mat &image, double confidenceThreshold, double iouThreshold, - int32_t maxInstances, const std::vector &classIndices, - bool returnMaskAtOriginalResolution, const std::string &methodName) { + const cv::Mat &image, double confidenceThreshold, double iouThreshold, int32_t maxInstances, + const std::vector &classIndices, bool returnMaskAtOriginalResolution, + const std::string &methodName) { std::scoped_lock lock(inference_mutex_); @@ -66,8 +62,7 @@ std::vector BaseInstanceSegmentation::runInference( auto inputShapes = getAllInputShapes(methodName); if (inputShapes.empty() || inputShapes[0].empty()) { throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, - "Method '" + methodName + - "' has invalid input tensor shape."); + "Method '" + methodName + "' has invalid input tensor shape."); } modelInputShape_ = inputShapes[0]; @@ -77,51 +72,47 @@ std::vector BaseInstanceSegmentation::runInference( validateThresholds(confidenceThreshold, iouThreshold); - auto forwardResult = - BaseModel::execute(methodName, {buildInputTensor(image)}); + auto forwardResult = BaseModel::execute(methodName, {buildInputTensor(image)}); CHECK_OK_OR_THROW_FORWARD_ERROR(forwardResult); validateOutputTensors(forwardResult.get()); - auto instances = collectInstances( - forwardResult.get(), originalSize, modelInputSize, confidenceThreshold, - classIndices, returnMaskAtOriginalResolution); + auto instances = + collectInstances(forwardResult.get(), originalSize, modelInputSize, confidenceThreshold, + classIndices, returnMaskAtOriginalResolution); return finalizeInstances(std::move(instances), iouThreshold, maxInstances); } std::vector BaseInstanceSegmentation::generateFromString( - std::string imageSource, double confidenceThreshold, double iouThreshold, - int32_t maxInstances, std::vector classIndices, - bool returnMaskAtOriginalResolution, std::string methodName) { + std::string imageSource, double confidenceThreshold, double iouThreshold, int32_t maxInstances, + std::vector classIndices, bool returnMaskAtOriginalResolution, + std::string methodName) { cv::Mat imageBGR = image_processing::readImage(imageSource); cv::Mat imageRGB; cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB); - return runInference(imageRGB, confidenceThreshold, iouThreshold, maxInstances, - classIndices, returnMaskAtOriginalResolution, methodName); + return runInference(imageRGB, confidenceThreshold, iouThreshold, maxInstances, classIndices, + returnMaskAtOriginalResolution, methodName); } std::vector BaseInstanceSegmentation::generateFromFrame( - jsi::Runtime &runtime, const jsi::Value &frameData, - double confidenceThreshold, double iouThreshold, int32_t maxInstances, - std::vector classIndices, bool returnMaskAtOriginalResolution, - std::string methodName) { + jsi::Runtime &runtime, const jsi::Value &frameData, double confidenceThreshold, + double iouThreshold, int32_t maxInstances, std::vector classIndices, + bool returnMaskAtOriginalResolution, std::string methodName) { auto orient = ::rnexecutorch::utils::readFrameOrientation(runtime, frameData); cv::Mat frame = extractFromFrame(runtime, frameData); cv::Mat rotated = utils::rotateFrameForModel(frame, orient); - auto instances = - runInference(rotated, confidenceThreshold, iouThreshold, maxInstances, - classIndices, returnMaskAtOriginalResolution, methodName); + auto instances = runInference(rotated, confidenceThreshold, iouThreshold, maxInstances, + classIndices, returnMaskAtOriginalResolution, methodName); for (auto &inst : instances) { utils::inverseRotateBbox(inst.bbox, orient, rotated.size()); // Inverse-rotate the mask to match the screen orientation - cv::Mat maskMat(inst.maskHeight, inst.maskWidth, CV_8UC1, - inst.mask->data()); + cv::Mat maskMat(inst.maskHeight, inst.maskWidth, CV_8UC1, inst.mask->data()); cv::Mat invMask = utils::inverseRotateMat(maskMat, orient); - inst.mask = std::make_shared( - invMask.data, static_cast(invMask.total())); + inst.mask = + std::make_shared(invMask.data, static_cast(invMask.total())); inst.maskWidth = invMask.cols; inst.maskHeight = invMask.rows; } @@ -130,30 +121,28 @@ std::vector BaseInstanceSegmentation::generateFromFrame( std::vector BaseInstanceSegmentation::generateFromPixels( JSTensorViewIn tensorView, double confidenceThreshold, double iouThreshold, - int32_t maxInstances, std::vector classIndices, - bool returnMaskAtOriginalResolution, std::string methodName) { + int32_t maxInstances, std::vector classIndices, bool returnMaskAtOriginalResolution, + std::string methodName) { cv::Mat image = extractFromPixels(tensorView); - return runInference(image, confidenceThreshold, iouThreshold, maxInstances, - classIndices, returnMaskAtOriginalResolution, methodName); + return runInference(image, confidenceThreshold, iouThreshold, maxInstances, classIndices, + returnMaskAtOriginalResolution, methodName); } std::tuple -BaseInstanceSegmentation::extractDetectionData(const float *bboxData, - const float *scoresData, +BaseInstanceSegmentation::extractDetectionData(const float *bboxData, const float *scoresData, int32_t index) { - utils::computer_vision::BBox bbox{ - bboxData[index * 4], bboxData[index * 4 + 1], bboxData[index * 4 + 2], - bboxData[index * 4 + 3]}; + utils::computer_vision::BBox bbox{bboxData[index * 4], bboxData[index * 4 + 1], + bboxData[index * 4 + 2], bboxData[index * 4 + 3]}; float score = scoresData[index * 2]; int32_t label = static_cast(scoresData[index * 2 + 1]); return {bbox, score, label}; } -cv::Rect BaseInstanceSegmentation::computeMaskCropRect( - const utils::computer_vision::BBox &bboxModel, cv::Size modelInputSize, - cv::Size maskSize) { +cv::Rect +BaseInstanceSegmentation::computeMaskCropRect(const utils::computer_vision::BBox &bboxModel, + cv::Size modelInputSize, cv::Size maskSize) { float mx1F = bboxModel.p1.x * maskSize.width / modelInputSize.width; float my1F = bboxModel.p1.y * maskSize.height / modelInputSize.height; @@ -163,14 +152,12 @@ cv::Rect BaseInstanceSegmentation::computeMaskCropRect( int32_t mx1 = std::max(0, static_cast(std::floor(mx1F))); int32_t my1 = std::max(0, static_cast(std::floor(my1F))); int32_t mx2 = std::min(maskSize.width, static_cast(std::ceil(mx2F))); - int32_t my2 = - std::min(maskSize.height, static_cast(std::ceil(my2F))); + int32_t my2 = std::min(maskSize.height, static_cast(std::ceil(my2F))); return {mx1, my1, mx2 - mx1, my2 - my1}; } -cv::Rect BaseInstanceSegmentation::addPaddingToRect(const cv::Rect &rect, - cv::Size maskSize) { +cv::Rect BaseInstanceSegmentation::addPaddingToRect(const cv::Rect &rect, cv::Size maskSize) { int32_t x1 = std::max(0, rect.x - 1); int32_t y1 = std::max(0, rect.y - 1); int32_t x2 = std::min(maskSize.width, rect.x + rect.width + 1); @@ -180,15 +167,14 @@ cv::Rect BaseInstanceSegmentation::addPaddingToRect(const cv::Rect &rect, } cv::Mat BaseInstanceSegmentation::warpToOriginalResolution( - const cv::Mat &probMat, const cv::Rect &maskRect, cv::Size originalSize, - cv::Size maskSize, const utils::computer_vision::BBox &bboxOriginal) { + const cv::Mat &probMat, const cv::Rect &maskRect, cv::Size originalSize, cv::Size maskSize, + const utils::computer_vision::BBox &bboxOriginal) { float scaleX = static_cast(originalSize.width) / maskSize.width; float scaleY = static_cast(originalSize.height) / maskSize.height; - cv::Mat M = (cv::Mat_(2, 3) << scaleX, 0, - (maskRect.x * scaleX - bboxOriginal.p1.x), 0, scaleY, - (maskRect.y * scaleY - bboxOriginal.p1.y)); + cv::Mat M = (cv::Mat_(2, 3) << scaleX, 0, (maskRect.x * scaleX - bboxOriginal.p1.x), 0, + scaleY, (maskRect.y * scaleY - bboxOriginal.p1.y)); cv::Size bboxSize(static_cast(std::round(bboxOriginal.width())), static_cast(std::round(bboxOriginal.height()))); @@ -220,8 +206,7 @@ cv::Mat BaseInstanceSegmentation::processMaskFromLogits( cv::Mat probMat = image_processing::applySigmoid(cropped); if (warpToOriginal) { - probMat = warpToOriginalResolution(probMat, cropRect, originalSize, - maskSize, bboxOriginal); + probMat = warpToOriginalResolution(probMat, cropRect, originalSize, maskSize, bboxOriginal); } return thresholdToBinary(probMat); } @@ -229,10 +214,9 @@ cv::Mat BaseInstanceSegmentation::processMaskFromLogits( void BaseInstanceSegmentation::validateThresholds(double confidenceThreshold, double iouThreshold) const { if (confidenceThreshold < 0 || confidenceThreshold > 1) { - throw RnExecutorchError( - RnExecutorchErrorCode::InvalidConfig, - "Confidence threshold must be greater or equal to 0 " - "and less than or equal to 1."); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidConfig, + "Confidence threshold must be greater or equal to 0 " + "and less than or equal to 1."); } if (iouThreshold < 0 || iouThreshold > 1) { @@ -242,8 +226,7 @@ void BaseInstanceSegmentation::validateThresholds(double confidenceThreshold, } } -void BaseInstanceSegmentation::validateOutputTensors( - const std::vector &tensors) const { +void BaseInstanceSegmentation::validateOutputTensors(const std::vector &tensors) const { if (tensors.size() != 3) { throw RnExecutorchError(RnExecutorchErrorCode::InvalidModelOutput, "Expected 3 output tensors ([1,N,4] + [1,N,2] + " @@ -252,8 +235,8 @@ void BaseInstanceSegmentation::validateOutputTensors( } } -std::set BaseInstanceSegmentation::prepareAllowedClasses( - const std::vector &classIndices) const { +std::set +BaseInstanceSegmentation::prepareAllowedClasses(const std::vector &classIndices) const { std::set allowedClasses; if (!classIndices.empty()) { allowedClasses.insert(classIndices.begin(), classIndices.end()); @@ -261,13 +244,11 @@ std::set BaseInstanceSegmentation::prepareAllowedClasses( return allowedClasses; } -void BaseInstanceSegmentation::ensureMethodLoaded( - const std::string &methodName) { +void BaseInstanceSegmentation::ensureMethodLoaded(const std::string &methodName) { if (methodName.empty()) { - throw RnExecutorchError( - RnExecutorchErrorCode::InvalidConfig, - "Method name cannot be empty. Use 'forward' for single-method models " - "or 'forward_{inputSize}' for multi-method models."); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidConfig, + "Method name cannot be empty. Use 'forward' for single-method models " + "or 'forward_{inputSize}' for multi-method models."); } if (currentlyLoadedMethod_ == methodName) { @@ -276,8 +257,7 @@ void BaseInstanceSegmentation::ensureMethodLoaded( if (!module_) { throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, - "Model not loaded. Cannot load method '" + - methodName + "'."); + "Model not loaded. Cannot load method '" + methodName + "'."); } if (!currentlyLoadedMethod_.empty()) { @@ -286,21 +266,19 @@ void BaseInstanceSegmentation::ensureMethodLoaded( auto loadResult = module_->load_method(methodName); if (loadResult != executorch::runtime::Error::Ok) { - throw RnExecutorchError( - loadResult, "Failed to load method '" + methodName + - "'. Ensure the method exists in the exported model."); + throw RnExecutorchError(loadResult, "Failed to load method '" + methodName + + "'. Ensure the method exists in the exported model."); } currentlyLoadedMethod_ = methodName; } -std::vector BaseInstanceSegmentation::finalizeInstances( - std::vector instances, double iouThreshold, - int32_t maxInstances) const { +std::vector +BaseInstanceSegmentation::finalizeInstances(std::vector instances, + double iouThreshold, int32_t maxInstances) const { if (applyNMS_) { - instances = - utils::computer_vision::nonMaxSuppression(instances, iouThreshold); + instances = utils::computer_vision::nonMaxSuppression(instances, iouThreshold); } if (std::cmp_greater(instances.size(), maxInstances)) { @@ -311,15 +289,12 @@ std::vector BaseInstanceSegmentation::finalizeInstances( } std::vector BaseInstanceSegmentation::collectInstances( - const std::vector &tensors, cv::Size originalSize, - cv::Size modelInputSize, double confidenceThreshold, - const std::vector &classIndices, + const std::vector &tensors, cv::Size originalSize, cv::Size modelInputSize, + double confidenceThreshold, const std::vector &classIndices, bool returnMaskAtOriginalResolution) { - float widthRatio = - static_cast(originalSize.width) / modelInputSize.width; - float heightRatio = - static_cast(originalSize.height) / modelInputSize.height; + float widthRatio = static_cast(originalSize.width) / modelInputSize.width; + float heightRatio = static_cast(originalSize.height) / modelInputSize.height; auto allowedClasses = prepareAllowedClasses(classIndices); // CONTRACT @@ -335,38 +310,32 @@ std::vector BaseInstanceSegmentation::collectInstances( const float *scoresData = scoresTensor.const_data_ptr(); const float *maskData = maskTensor.const_data_ptr(); - auto isValidDetection = - [&allowedClasses, &confidenceThreshold](float score, int32_t labelIdx) { - return score >= confidenceThreshold && - (allowedClasses.empty() || allowedClasses.count(labelIdx) != 0); - }; + auto isValidDetection = [&allowedClasses, &confidenceThreshold](float score, int32_t labelIdx) { + return score >= confidenceThreshold && + (allowedClasses.empty() || allowedClasses.count(labelIdx) != 0); + }; std::vector instances; for (int32_t i = 0; i < numInstances; ++i) { - auto [bboxModel, score, labelIdx] = - extractDetectionData(bboxData, scoresData, i); + auto [bboxModel, score, labelIdx] = extractDetectionData(bboxData, scoresData, i); if (!isValidDetection(score, labelIdx)) { continue; } - utils::computer_vision::BBox bboxOriginal = - bboxModel.scale(widthRatio, heightRatio); + utils::computer_vision::BBox bboxOriginal = bboxModel.scale(widthRatio, heightRatio); if (!bboxOriginal.isValid()) { continue; } - cv::Mat logitsMat(maskH, maskW, CV_32FC1, - const_cast(maskData + (i * maskH * maskW))); + cv::Mat logitsMat(maskH, maskW, CV_32FC1, const_cast(maskData + (i * maskH * maskW))); - cv::Mat binaryMask = processMaskFromLogits( - logitsMat, bboxModel, bboxOriginal, modelInputSize, originalSize, - returnMaskAtOriginalResolution); + cv::Mat binaryMask = processMaskFromLogits(logitsMat, bboxModel, bboxOriginal, modelInputSize, + originalSize, returnMaskAtOriginalResolution); instances.emplace_back(bboxOriginal, - std::make_shared( - binaryMask.data, binaryMask.total()), + std::make_shared(binaryMask.data, binaryMask.total()), binaryMask.cols, binaryMask.rows, labelIdx, score); } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.h index 341d0f2235..7d428b64ea 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.h @@ -18,83 +18,73 @@ using executorch::runtime::EValue; class BaseInstanceSegmentation : public VisionModel { public: - BaseInstanceSegmentation(const std::string &modelSource, - std::vector normMean, + BaseInstanceSegmentation(const std::string &modelSource, std::vector normMean, std::vector normStd, bool applyNMS, std::shared_ptr callInvoker); [[nodiscard("Registered non-void function")]] std::vector - generateFromString(std::string imageSource, double confidenceThreshold, - double iouThreshold, int32_t maxInstances, - std::vector classIndices, - bool returnMaskAtOriginalResolution, - std::string methodName); + generateFromString(std::string imageSource, double confidenceThreshold, double iouThreshold, + int32_t maxInstances, std::vector classIndices, + bool returnMaskAtOriginalResolution, std::string methodName); [[nodiscard("Registered non-void function")]] std::vector - generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData, - double confidenceThreshold, double iouThreshold, - int32_t maxInstances, std::vector classIndices, - bool returnMaskAtOriginalResolution, - std::string methodName); + generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData, double confidenceThreshold, + double iouThreshold, int32_t maxInstances, std::vector classIndices, + bool returnMaskAtOriginalResolution, std::string methodName); [[nodiscard("Registered non-void function")]] std::vector - generateFromPixels(JSTensorViewIn tensorView, double confidenceThreshold, - double iouThreshold, int32_t maxInstances, - std::vector classIndices, - bool returnMaskAtOriginalResolution, - std::string methodName); + generateFromPixels(JSTensorViewIn tensorView, double confidenceThreshold, double iouThreshold, + int32_t maxInstances, std::vector classIndices, + bool returnMaskAtOriginalResolution, std::string methodName); protected: cv::Size modelInputSize() const override; private: - std::vector runInference( - const cv::Mat &image, double confidenceThreshold, double iouThreshold, - int32_t maxInstances, const std::vector &classIndices, - bool returnMaskAtOriginalResolution, const std::string &methodName); + std::vector runInference(const cv::Mat &image, double confidenceThreshold, + double iouThreshold, int32_t maxInstances, + const std::vector &classIndices, + bool returnMaskAtOriginalResolution, + const std::string &methodName); TensorPtr buildInputTensor(const cv::Mat &image); - std::vector - collectInstances(const std::vector &tensors, cv::Size originalSize, - cv::Size modelInputSize, double confidenceThreshold, - const std::vector &classIndices, - bool returnMaskAtOriginalResolution); + std::vector collectInstances(const std::vector &tensors, + cv::Size originalSize, cv::Size modelInputSize, + double confidenceThreshold, + const std::vector &classIndices, + bool returnMaskAtOriginalResolution); - void validateThresholds(double confidenceThreshold, - double iouThreshold) const; + void validateThresholds(double confidenceThreshold, double iouThreshold) const; void validateOutputTensors(const std::vector &tensors) const; - std::set - prepareAllowedClasses(const std::vector &classIndices) const; + std::set prepareAllowedClasses(const std::vector &classIndices) const; // Model loading and input helpers void ensureMethodLoaded(const std::string &methodName); std::tuple - extractDetectionData(const float *bboxData, const float *scoresData, - int32_t index); + extractDetectionData(const float *bboxData, const float *scoresData, int32_t index); cv::Rect computeMaskCropRect(const utils::computer_vision::BBox &bboxModel, cv::Size modelInputSize, cv::Size maskSize); cv::Rect addPaddingToRect(const cv::Rect &rect, cv::Size maskSize); - cv::Mat - warpToOriginalResolution(const cv::Mat &probMat, const cv::Rect &maskRect, - cv::Size originalSize, cv::Size maskSize, - const utils::computer_vision::BBox &bboxOriginal); + cv::Mat warpToOriginalResolution(const cv::Mat &probMat, const cv::Rect &maskRect, + cv::Size originalSize, cv::Size maskSize, + const utils::computer_vision::BBox &bboxOriginal); cv::Mat thresholdToBinary(const cv::Mat &probMat); - std::vector - finalizeInstances(std::vector instances, double iouThreshold, - int32_t maxInstances) const; + std::vector finalizeInstances(std::vector instances, + double iouThreshold, int32_t maxInstances) const; - cv::Mat processMaskFromLogits( - const cv::Mat &logitsMat, const utils::computer_vision::BBox &bboxModel, - const utils::computer_vision::BBox &bboxOriginal, cv::Size modelInputSize, - cv::Size originalSize, bool warpToOriginal); + cv::Mat processMaskFromLogits(const cv::Mat &logitsMat, + const utils::computer_vision::BBox &bboxModel, + const utils::computer_vision::BBox &bboxOriginal, + cv::Size modelInputSize, cv::Size originalSize, + bool warpToOriginal); std::optional normMean_; std::optional normStd_; @@ -103,7 +93,7 @@ class BaseInstanceSegmentation : public VisionModel { }; } // namespace models::instance_segmentation -REGISTER_CONSTRUCTOR(models::instance_segmentation::BaseInstanceSegmentation, - std::string, std::vector, std::vector, bool, +REGISTER_CONSTRUCTOR(models::instance_segmentation::BaseInstanceSegmentation, std::string, + std::vector, std::vector, bool, std::shared_ptr); } // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/Types.h b/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/Types.h index 9006688ce1..cf35c38215 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/Types.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/Types.h @@ -16,11 +16,10 @@ namespace rnexecutorch::models::instance_segmentation::types { struct Instance { Instance() = default; - Instance(utils::computer_vision::BBox bbox, - std::shared_ptr mask, int32_t maskWidth, - int32_t maskHeight, int32_t classIndex, float score) - : bbox(bbox), mask(std::move(mask)), maskWidth(maskWidth), - maskHeight(maskHeight), classIndex(classIndex), score(score) {} + Instance(utils::computer_vision::BBox bbox, std::shared_ptr mask, + int32_t maskWidth, int32_t maskHeight, int32_t classIndex, float score) + : bbox(bbox), mask(std::move(mask)), maskWidth(maskWidth), maskHeight(maskHeight), + classIndex(classIndex), score(score) {} utils::computer_vision::BBox bbox; std::shared_ptr mask; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp index 924bba9f99..289297245f 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.cpp @@ -19,25 +19,21 @@ using executorch::extension::module::Module; using executorch::runtime::Error; LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource, - std::vector capabilities, - std::shared_ptr callInvoker) + std::vector capabilities, std::shared_ptr callInvoker) : BaseModel(modelSource, callInvoker, Module::LoadMode::Mmap) { if (capabilities.empty()) { - runner_ = - std::make_unique(std::move(module_), tokenizerSource); + runner_ = std::make_unique(std::move(module_), tokenizerSource); } else { std::map> encoders; for (const auto &cap : capabilities) { if (cap == "vision") { - encoders[llm::MultimodalType::Image] = - std::make_unique(*module_); + encoders[llm::MultimodalType::Image] = std::make_unique(*module_); } else if (cap == "audio") { - encoders[llm::MultimodalType::Audio] = - std::make_unique(*module_); + encoders[llm::MultimodalType::Audio] = std::make_unique(*module_); } } - runner_ = std::make_unique( - std::move(module_), tokenizerSource, std::move(encoders)); + runner_ = std::make_unique(std::move(module_), tokenizerSource, + std::move(encoders)); } auto loadResult = runner_->load(); @@ -53,11 +49,9 @@ LLM::LLM(const std::string &modelSource, const std::string &tokenizerSource, memorySizeLowerBound = fs::file_size(fs::path(tokenizerSource)); } -std::string LLM::generate(std::string input, - std::shared_ptr callback) { +std::string LLM::generate(std::string input, std::shared_ptr callback) { if (!runner_ || !runner_->is_loaded()) { - throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, - "Runner is not loaded"); + throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, "Runner is not loaded"); } std::string output; auto nativeCallback = [this, callback, &output](const std::string &token) { @@ -77,22 +71,18 @@ std::string LLM::generate(std::string input, return output; } -std::string LLM::generateMultimodal(std::string prompt, - std::shared_ptr callback, +std::string LLM::generateMultimodal(std::string prompt, std::shared_ptr callback, MultimodalInputs mutlimodalInputs) { if (!runner_ || !runner_->is_loaded()) { - throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, - "Runner is not loaded"); + throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, "Runner is not loaded"); } if (!runner_->is_multimodal()) { throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, "This model does not support multimodal input."); } - if (!mutlimodalInputs.images.has_value() && - !mutlimodalInputs.audios.has_value()) { - throw RnExecutorchError( - RnExecutorchErrorCode::InvalidUserInput, - "At least one of imageToken/audioToken must be non-empty"); + if (!mutlimodalInputs.images.has_value() && !mutlimodalInputs.audios.has_value()) { + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + "At least one of imageToken/audioToken must be non-empty"); } // Scan the prompt once, splitting at the earliest placeholder at each step @@ -110,8 +100,8 @@ std::string LLM::generateMultimodal(std::string prompt, inputs.push_back(llm::make_text_input(prompt.substr(pos))); break; } - const bool imageFirst = imgAt != std::string::npos && - (audAt == std::string::npos || imgAt < audAt); + const bool imageFirst = + imgAt != std::string::npos && (audAt == std::string::npos || imgAt < audAt); size_t at = imageFirst ? imgAt : audAt; if (at > pos) { inputs.push_back(llm::make_text_input(prompt.substr(pos, at - pos))); @@ -120,8 +110,7 @@ std::string LLM::generateMultimodal(std::string prompt, auto &images = mutlimodalInputs.images.value(); if (imageIdx >= images.paths.size()) { throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - "More '" + images.token + - "' placeholders than image paths"); + "More '" + images.token + "' placeholders than image paths"); } inputs.push_back(llm::make_image_input(images.paths[imageIdx++])); pos = at + images.token.size(); @@ -129,11 +118,9 @@ std::string LLM::generateMultimodal(std::string prompt, auto &audios = mutlimodalInputs.audios.value(); if (audioIdx >= audios.waveforms.size()) { throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - "More '" + audios.token + - "' placeholders than audio waveforms"); + "More '" + audios.token + "' placeholders than audio waveforms"); } - inputs.push_back( - llm::make_audio_input(std::move(audios.waveforms[audioIdx++]))); + inputs.push_back(llm::make_audio_input(std::move(audios.waveforms[audioIdx++]))); pos = at + audios.token.size(); } } @@ -141,13 +128,11 @@ std::string LLM::generateMultimodal(std::string prompt, imageIdx < mutlimodalInputs.images.value().paths.size()) || (mutlimodalInputs.audios.has_value() && audioIdx < mutlimodalInputs.audios.value().waveforms.size())) { - throw RnExecutorchError( - RnExecutorchErrorCode::InvalidUserInput, - "More image/audio paths provided than placeholders in prompt"); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + "More image/audio paths provided than placeholders in prompt"); } if (inputs.empty()) { - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - "No inputs to generate from"); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, "No inputs to generate from"); } std::string output; @@ -204,16 +189,13 @@ int32_t LLM::getVisualTokenCount() const { int32_t LLM::countTextTokens(std::string text) const { if (!runner_ || !runner_->is_loaded()) { - throw RnExecutorchError( - RnExecutorchErrorCode::ModuleNotLoaded, - "Can't count tokens from a model that's not loaded"); + throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, + "Can't count tokens from a model that's not loaded"); } return runner_->count_text_tokens(text); } -size_t LLM::getMemoryLowerBound() const noexcept { - return memorySizeLowerBound; -} +size_t LLM::getMemoryLowerBound() const noexcept { return memorySizeLowerBound; } void LLM::setCountInterval(size_t countInterval) { if (!runner_ || !runner_->is_loaded()) { @@ -289,9 +271,8 @@ void LLM::setRepetitionPenalty(float repetitionPenalty) { int32_t LLM::getMaxContextLength() const { if (!runner_ || !runner_->is_loaded()) { - throw RnExecutorchError( - RnExecutorchErrorCode::ModuleNotLoaded, - "Can't get context length from a model that's not loaded"); + throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, + "Can't get context length from a model that's not loaded"); } return runner_->get_max_context_length(); } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h index 4b7087351b..de7e86728a 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/llm/LLM.h @@ -16,16 +16,13 @@ using namespace facebook; class LLM : public BaseModel { public: - explicit LLM(const std::string &modelSource, - const std::string &tokenizerSource, + explicit LLM(const std::string &modelSource, const std::string &tokenizerSource, std::vector capabilities, std::shared_ptr callInvoker); - std::string generate(std::string prompt, - std::shared_ptr callback); + std::string generate(std::string prompt, std::shared_ptr callback); - std::string generateMultimodal(std::string prompt, - std::shared_ptr callback, + std::string generateMultimodal(std::string prompt, std::shared_ptr callback, MultimodalInputs mutlimodalInputs = {}); void interrupt(); @@ -49,7 +46,6 @@ class LLM : public BaseModel { }; } // namespace models::llm -REGISTER_CONSTRUCTOR(models::llm::LLM, std::string, std::string, - std::vector, +REGISTER_CONSTRUCTOR(models::llm::LLM, std::string, std::string, std::vector, std::shared_ptr); } // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/Constants.h b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/Constants.h index 95d58d596d..d01e609b3d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/Constants.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/Constants.h @@ -2,4 +2,4 @@ namespace rnexecutorch::models::object_detection::constants { inline constexpr float IOU_THRESHOLD = 0.55f; -} // namespace rnexecutorch::models::object_detection::constants \ No newline at end of file +} // namespace rnexecutorch::models::object_detection::constants diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp index 24c4e1083a..8a219cf423 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp @@ -14,23 +14,19 @@ namespace rnexecutorch::models::object_detection { -ObjectDetection::ObjectDetection( - const std::string &modelSource, std::vector normMean, - std::vector normStd, std::vector labelNames, - std::shared_ptr callInvoker) - : VisionModel(modelSource, callInvoker), - labelNames_(std::move(labelNames)) { +ObjectDetection::ObjectDetection(const std::string &modelSource, std::vector normMean, + std::vector normStd, std::vector labelNames, + std::shared_ptr callInvoker) + : VisionModel(modelSource, callInvoker), labelNames_(std::move(labelNames)) { if (normMean.size() == 3) { normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]); } else if (!normMean.empty()) { - log(LOG_LEVEL::Warn, - "normMean must have 3 elements — ignoring provided value."); + log(LOG_LEVEL::Warn, "normMean must have 3 elements — ignoring provided value."); } if (normStd.size() == 3) { normStd_ = cv::Scalar(normStd[0], normStd[1], normStd[2]); } else if (!normStd.empty()) { - log(LOG_LEVEL::Warn, - "normStd must have 3 elements — ignoring provided value."); + log(LOG_LEVEL::Warn, "normStd must have 3 elements — ignoring provided value."); } } @@ -51,30 +47,27 @@ cv::Size ObjectDetection::modelInputSize() const { void ObjectDetection::ensureMethodLoaded(const std::string &methodName) { if (methodName.empty()) { - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - "methodName cannot be empty"); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, "methodName cannot be empty"); } if (currentlyLoadedMethod_ == methodName) { return; } if (!module_) { - throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, - "Model module is not loaded"); + throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, "Model module is not loaded"); } if (!currentlyLoadedMethod_.empty()) { module_->unload_method(currentlyLoadedMethod_); } auto loadResult = module_->load_method(methodName); if (loadResult != executorch::runtime::Error::Ok) { - throw RnExecutorchError( - loadResult, "Failed to load method '" + methodName + - "'. Ensure the method exists in the exported model."); + throw RnExecutorchError(loadResult, "Failed to load method '" + methodName + + "'. Ensure the method exists in the exported model."); } currentlyLoadedMethod_ = methodName; } -std::set ObjectDetection::prepareAllowedClasses( - const std::vector &classIndices) const { +std::set +ObjectDetection::prepareAllowedClasses(const std::vector &classIndices) const { std::set allowedClasses; if (!classIndices.empty()) { allowedClasses.insert(classIndices.begin(), classIndices.end()); @@ -83,33 +76,28 @@ std::set ObjectDetection::prepareAllowedClasses( } std::vector -ObjectDetection::postprocess(const std::vector &tensors, - cv::Size originalSize, double detectionThreshold, - double iouThreshold, +ObjectDetection::postprocess(const std::vector &tensors, cv::Size originalSize, + double detectionThreshold, double iouThreshold, const std::vector &classIndices) { const cv::Size inputSize = modelInputSize(); float widthRatio = static_cast(originalSize.width) / inputSize.width; - float heightRatio = - static_cast(originalSize.height) / inputSize.height; + float heightRatio = static_cast(originalSize.height) / inputSize.height; // Prepare allowed classes set for filtering auto allowedClasses = prepareAllowedClasses(classIndices); std::vector detections; auto bboxTensor = tensors.at(0).toTensor(); - std::span bboxes( - static_cast(bboxTensor.const_data_ptr()), - bboxTensor.numel()); + std::span bboxes(static_cast(bboxTensor.const_data_ptr()), + bboxTensor.numel()); auto scoreTensor = tensors.at(1).toTensor(); - std::span scores( - static_cast(scoreTensor.const_data_ptr()), - scoreTensor.numel()); + std::span scores(static_cast(scoreTensor.const_data_ptr()), + scoreTensor.numel()); auto labelTensor = tensors.at(2).toTensor(); - std::span labels( - static_cast(labelTensor.const_data_ptr()), - labelTensor.numel()); + std::span labels(static_cast(labelTensor.const_data_ptr()), + labelTensor.numel()); for (std::size_t i = 0; i < scores.size(); ++i) { if (scores[i] < detectionThreshold) { @@ -119,8 +107,7 @@ ObjectDetection::postprocess(const std::vector &tensors, auto labelIdx = static_cast(labels[i]); // Filter by class if classesOfInterest is specified - if (!allowedClasses.empty() && - allowedClasses.find(labelIdx) == allowedClasses.end()) { + if (!allowedClasses.empty() && allowedClasses.find(labelIdx) == allowedClasses.end()) { continue; } @@ -130,22 +117,22 @@ ObjectDetection::postprocess(const std::vector &tensors, float y2 = bboxes[i * 4 + 3] * heightRatio; if (std::cmp_greater_equal(labelIdx, labelNames_.size())) { - throw RnExecutorchError( - RnExecutorchErrorCode::InvalidConfig, - "Model output class index " + std::to_string(labelIdx) + - " exceeds labelNames size " + std::to_string(labelNames_.size()) + - ". Ensure the labelMap covers all model output classes."); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidConfig, + "Model output class index " + std::to_string(labelIdx) + + " exceeds labelNames size " + std::to_string(labelNames_.size()) + + ". Ensure the labelMap covers all model output classes."); } - detections.emplace_back(utils::computer_vision::BBox{x1, y1, x2, y2}, - labelNames_[labelIdx], labelIdx, scores[i]); + detections.emplace_back(utils::computer_vision::BBox{x1, y1, x2, y2}, labelNames_[labelIdx], + labelIdx, scores[i]); } return utils::computer_vision::nonMaxSuppression(detections, iouThreshold); } -std::vector ObjectDetection::runInference( - cv::Mat image, double detectionThreshold, double iouThreshold, - const std::vector &classIndices, const std::string &methodName) { +std::vector +ObjectDetection::runInference(cv::Mat image, double detectionThreshold, double iouThreshold, + const std::vector &classIndices, + const std::string &methodName) { if (detectionThreshold < 0.0 || detectionThreshold > 1.0) { throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, "detectionThreshold must be in range [0, 1]"); @@ -166,53 +153,50 @@ std::vector ObjectDetection::runInference( auto inputShapes = getAllInputShapes(methodName); if (inputShapes.empty() || inputShapes[0].size() < 2) { throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, - "Could not determine input shape for method: " + - methodName); + "Could not determine input shape for method: " + methodName); } modelInputShape_ = inputShapes[0]; cv::Mat preprocessed = preprocess(image); - auto inputTensor = - (normMean_ && normStd_) - ? image_processing::getTensorFromMatrix( - modelInputShape_, preprocessed, *normMean_, *normStd_) - : image_processing::getTensorFromMatrix(modelInputShape_, - preprocessed); + auto inputTensor = (normMean_ && normStd_) + ? image_processing::getTensorFromMatrix(modelInputShape_, preprocessed, + *normMean_, *normStd_) + : image_processing::getTensorFromMatrix(modelInputShape_, preprocessed); auto executeResult = execute(methodName, {inputTensor}); if (!executeResult.ok()) { - throw RnExecutorchError(executeResult.error(), - "The model's " + methodName + - " method did not succeed. " - "Ensure the model input is correct."); + throw RnExecutorchError(executeResult.error(), "The model's " + methodName + + " method did not succeed. " + "Ensure the model input is correct."); } - return postprocess(executeResult.get(), originalSize, detectionThreshold, - iouThreshold, classIndices); + return postprocess(executeResult.get(), originalSize, detectionThreshold, iouThreshold, + classIndices); } -std::vector ObjectDetection::generateFromString( - std::string imageSource, double detectionThreshold, double iouThreshold, - std::vector classIndices, std::string methodName) { +std::vector ObjectDetection::generateFromString(std::string imageSource, + double detectionThreshold, + double iouThreshold, + std::vector classIndices, + std::string methodName) { cv::Mat imageBGR = image_processing::readImage(imageSource); cv::Mat imageRGB; cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB); - return runInference(imageRGB, detectionThreshold, iouThreshold, classIndices, - methodName); + return runInference(imageRGB, detectionThreshold, iouThreshold, classIndices, methodName); } -std::vector ObjectDetection::generateFromFrame( - jsi::Runtime &runtime, const jsi::Value &frameData, - double detectionThreshold, double iouThreshold, - std::vector classIndices, std::string methodName) { +std::vector +ObjectDetection::generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData, + double detectionThreshold, double iouThreshold, + std::vector classIndices, std::string methodName) { auto orient = ::rnexecutorch::utils::readFrameOrientation(runtime, frameData); cv::Mat frame = extractFromFrame(runtime, frameData); cv::Mat rotated = ::rnexecutorch::utils::rotateFrameForModel(frame, orient); - auto detections = runInference(rotated, detectionThreshold, iouThreshold, - classIndices, methodName); + auto detections = + runInference(rotated, detectionThreshold, iouThreshold, classIndices, methodName); for (auto &det : detections) { ::rnexecutorch::utils::inverseRotateBbox(det.bbox, orient, rotated.size()); @@ -220,12 +204,13 @@ std::vector ObjectDetection::generateFromFrame( return detections; } -std::vector ObjectDetection::generateFromPixels( - JSTensorViewIn pixelData, double detectionThreshold, double iouThreshold, - std::vector classIndices, std::string methodName) { +std::vector ObjectDetection::generateFromPixels(JSTensorViewIn pixelData, + double detectionThreshold, + double iouThreshold, + std::vector classIndices, + std::string methodName) { cv::Mat image = extractFromPixels(pixelData); - return runInference(image, detectionThreshold, iouThreshold, classIndices, - methodName); + return runInference(image, detectionThreshold, iouThreshold, classIndices, methodName); } } // namespace rnexecutorch::models::object_detection diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h index 6e3c01356e..d18e0952cc 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h @@ -43,8 +43,7 @@ class ObjectDetection : public VisionModel { * is incompatible. */ ObjectDetection(const std::string &modelSource, std::vector normMean, - std::vector normStd, - std::vector labelNames, + std::vector normStd, std::vector labelNames, std::shared_ptr callInvoker); /** @@ -73,17 +72,14 @@ class ObjectDetection : public VisionModel { * fails. */ [[nodiscard("Registered non-void function")]] std::vector - generateFromString(std::string imageSource, double detectionThreshold, - double iouThreshold, std::vector classIndices, - std::string methodName); + generateFromString(std::string imageSource, double detectionThreshold, double iouThreshold, + std::vector classIndices, std::string methodName); [[nodiscard("Registered non-void function")]] std::vector - generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData, - double detectionThreshold, double iouThreshold, - std::vector classIndices, std::string methodName); + generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData, double detectionThreshold, + double iouThreshold, std::vector classIndices, std::string methodName); [[nodiscard("Registered non-void function")]] std::vector - generateFromPixels(JSTensorViewIn pixelData, double detectionThreshold, - double iouThreshold, std::vector classIndices, - std::string methodName); + generateFromPixels(JSTensorViewIn pixelData, double detectionThreshold, double iouThreshold, + std::vector classIndices, std::string methodName); protected: /** @@ -96,10 +92,10 @@ class ObjectDetection : public VisionModel { */ cv::Size modelInputSize() const override; - std::vector - runInference(cv::Mat image, double detectionThreshold, double iouThreshold, - const std::vector &classIndices, - const std::string &methodName); + std::vector runInference(cv::Mat image, double detectionThreshold, + double iouThreshold, + const std::vector &classIndices, + const std::string &methodName); private: /** @@ -120,10 +116,10 @@ class ObjectDetection : public VisionModel { * @throws RnExecutorchError if the model outputs a class index that exceeds * the size of @ref labelNames_. */ - std::vector - postprocess(const std::vector &tensors, cv::Size originalSize, - double detectionThreshold, double iouThreshold, - const std::vector &classIndices); + std::vector postprocess(const std::vector &tensors, + cv::Size originalSize, double detectionThreshold, + double iouThreshold, + const std::vector &classIndices); /** * @brief Ensures the specified method is loaded, unloading any previous @@ -141,8 +137,7 @@ class ObjectDetection : public VisionModel { * @param classIndices Vector of class indices to allow. * @return A set containing the allowed class indices. */ - std::set - prepareAllowedClasses(const std::vector &classIndices) const; + std::set prepareAllowedClasses(const std::vector &classIndices) const; /// Optional per-channel mean for input normalisation (set in constructor). std::optional normMean_; @@ -158,8 +153,7 @@ class ObjectDetection : public VisionModel { }; } // namespace models::object_detection -REGISTER_CONSTRUCTOR(models::object_detection::ObjectDetection, std::string, - std::vector, std::vector, - std::vector, +REGISTER_CONSTRUCTOR(models::object_detection::ObjectDetection, std::string, std::vector, + std::vector, std::vector, std::shared_ptr); } // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/Types.h b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/Types.h index 1652516e89..8f3f8206f5 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/Types.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/Types.h @@ -8,10 +8,8 @@ namespace rnexecutorch::models::object_detection::types { struct Detection { Detection() = default; - Detection(utils::computer_vision::BBox bbox, std::string label, - int32_t classIndex, float score) - : bbox(bbox), label(std::move(label)), classIndex(classIndex), - score(score) {} + Detection(utils::computer_vision::BBox bbox, std::string label, int32_t classIndex, float score) + : bbox(bbox), label(std::move(label)), classIndex(classIndex), score(score) {} utils::computer_vision::BBox bbox; std::string label; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/CTCLabelConverter.cpp b/packages/react-native-executorch/common/rnexecutorch/models/ocr/CTCLabelConverter.cpp index 6a1913ae7d..fe91271ce2 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/CTCLabelConverter.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/CTCLabelConverter.cpp @@ -5,12 +5,10 @@ namespace rnexecutorch::models::ocr { CTCLabelConverter::CTCLabelConverter(const std::string &characters) - : ignoreIdx(0), - character({"[blank]"}) // blank character is ignored character (index 0). + : ignoreIdx(0), character({"[blank]"}) // blank character is ignored character (index 0). { if (characters.empty()) { - throw RnExecutorchError(RnExecutorchErrorCode::InvalidConfig, - "Character set cannot be empty"); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidConfig, "Character set cannot be empty"); } for (size_t i = 0; i < characters.length();) { size_t char_len = 0; @@ -38,9 +36,8 @@ CTCLabelConverter::CTCLabelConverter(const std::string &characters) } } -std::vector -CTCLabelConverter::decodeGreedy(const std::vector &textIndex, - size_t length) { +std::vector CTCLabelConverter::decodeGreedy(const std::vector &textIndex, + size_t length) { /* The current strategy used for decoding is greedy approach which iterates through the list of indices and process @@ -67,13 +64,11 @@ CTCLabelConverter::decodeGreedy(const std::vector &textIndex, if (!subArray.empty()) { std::optional lastChar; for (int32_t currentChar : subArray) { - bool isRepeated = - lastChar.has_value() && lastChar.value() == currentChar; + bool isRepeated = lastChar.has_value() && lastChar.value() == currentChar; bool isIgnored = currentChar == ignoreIdx; lastChar = currentChar; - if (currentChar >= 0 && - currentChar < static_cast(character.size()) && + if (currentChar >= 0 && currentChar < static_cast(character.size()) && !isRepeated && !isIgnored) { text += character[currentChar]; } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/CTCLabelConverter.h b/packages/react-native-executorch/common/rnexecutorch/models/ocr/CTCLabelConverter.h index 2b34847e50..8df2cebe55 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/CTCLabelConverter.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/CTCLabelConverter.h @@ -19,8 +19,7 @@ class CTCLabelConverter final { public: explicit CTCLabelConverter(const std::string &characters); - std::vector decodeGreedy(const std::vector &textIndex, - size_t length); + std::vector decodeGreedy(const std::vector &textIndex, size_t length); private: int32_t ignoreIdx; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/Constants.h b/packages/react-native-executorch/common/rnexecutorch/models/ocr/Constants.h index 9b96f17615..038619ad04 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/Constants.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/Constants.h @@ -24,8 +24,7 @@ inline constexpr int32_t kLargeRecognizerWidth = 512; inline constexpr int32_t kMediumRecognizerWidth = 256; inline constexpr int32_t kSmallRecognizerWidth = 128; inline constexpr int32_t kSmallVerticalRecognizerWidth = 64; -inline constexpr int32_t kMaxWidth = - kLargeRecognizerWidth + (kLargeRecognizerWidth * 0.15); +inline constexpr int32_t kMaxWidth = kLargeRecognizerWidth + (kLargeRecognizerWidth * 0.15); inline constexpr int32_t kSingleCharacterMinSize = 70; inline constexpr int32_t kRecognizerImageSize = 1280; inline constexpr int32_t kVerticalLineThreshold = 20; @@ -35,8 +34,8 @@ inline constexpr int32_t kLargeDetectorWidth = 1280; inline constexpr std::array kDetectorInputWidths = { kSmallDetectorWidth, kMediumDetectorWidth, kLargeDetectorWidth}; inline constexpr std::array kRecognizerInputWidths = { - kSmallVerticalRecognizerWidth, kSmallRecognizerWidth, - kMediumRecognizerWidth, kLargeRecognizerWidth}; + kSmallVerticalRecognizerWidth, kSmallRecognizerWidth, kMediumRecognizerWidth, + kLargeRecognizerWidth}; /* Mean and variance values for image normalization were used in EASYOCR pipeline diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/Detector.cpp b/packages/react-native-executorch/common/rnexecutorch/models/ocr/Detector.cpp index b5111d3f16..134c1e9bd5 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/Detector.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/Detector.cpp @@ -8,26 +8,22 @@ #include namespace rnexecutorch::models::ocr { -Detector::Detector(const std::string &modelSource, - std::shared_ptr callInvoker) +Detector::Detector(const std::string &modelSource, std::shared_ptr callInvoker) : BaseModel(modelSource, callInvoker) { for (auto input_size : constants::kDetectorInputWidths) { std::string methodName = "forward_" + std::to_string(input_size); auto inputShapes = getAllInputShapes(methodName); if (inputShapes[0].size() < 2) { - std::string errorMessage = - "Unexpected detector model input size for method: " + methodName + - "expected at least 2 dimensions but got: ." + - std::to_string(inputShapes[0].size()); - throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, - errorMessage); + std::string errorMessage = "Unexpected detector model input size for method: " + methodName + + "expected at least 2 dimensions but got: ." + + std::to_string(inputShapes[0].size()); + throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, errorMessage); } } } -std::vector Detector::generate(const cv::Mat &inputImage, - int32_t inputWidth) { +std::vector Detector::generate(const cv::Mat &inputImage, int32_t inputWidth) { /* Detector as an input accepts tensor with a shape of [1, 3, H, H]. where H is a constant for model. In our supported model it is currently @@ -36,19 +32,17 @@ std::vector Detector::generate(const cv::Mat &inputImage, original aspect ratio and the missing parts are filled with padding. */ - utils::validateInputWidth(inputWidth, constants::kDetectorInputWidths, - "Detector"); + utils::validateInputWidth(inputWidth, constants::kDetectorInputWidths, "Detector"); std::string methodName = "forward_" + std::to_string(inputWidth); auto inputShapes = getAllInputShapes(methodName); cv::Size modelInputSize = calculateModelImageSize(inputWidth); - cv::Mat resizedInputImage = - image_processing::resizePadded(inputImage, modelInputSize); - TensorPtr inputTensor = image_processing::getTensorFromMatrix( - inputShapes[0], resizedInputImage, constants::kNormalizationMean, - constants::kNormalizationVariance); + cv::Mat resizedInputImage = image_processing::resizePadded(inputImage, modelInputSize); + TensorPtr inputTensor = image_processing::getTensorFromMatrix(inputShapes[0], resizedInputImage, + constants::kNormalizationMean, + constants::kNormalizationVariance); auto forwardResult = BaseModel::execute(methodName, {inputTensor}); CHECK_OK_OR_THROW_FORWARD_ERROR(forwardResult); @@ -58,20 +52,18 @@ std::vector Detector::generate(const cv::Mat &inputImage, cv::Size Detector::calculateModelImageSize(int32_t methodInputWidth) { - utils::validateInputWidth(methodInputWidth, constants::kDetectorInputWidths, - "Detector"); + utils::validateInputWidth(methodInputWidth, constants::kDetectorInputWidths, "Detector"); std::string methodName = "forward_" + std::to_string(methodInputWidth); auto inputShapes = getAllInputShapes(methodName); std::vector modelInputShape = inputShapes[0]; - cv::Size modelInputSize = - cv::Size(modelInputShape[modelInputShape.size() - 1], - modelInputShape[modelInputShape.size() - 2]); + cv::Size modelInputSize = cv::Size(modelInputShape[modelInputShape.size() - 1], + modelInputShape[modelInputShape.size() - 2]); return modelInputSize; } -std::vector -Detector::postprocess(const Tensor &tensor, const cv::Size &modelInputSize) { +std::vector Detector::postprocess(const Tensor &tensor, + const cv::Size &modelInputSize) { /* The output of the model consists of two matrices (heat maps): 1. ScoreText(Score map) - The probability of a region containing character. @@ -79,22 +71,20 @@ Detector::postprocess(const Tensor &tensor, const cv::Size &modelInputSize) { group each character into a single instance (sequence) Both matrices are H/2xW/2 (400x400 or 640x640). */ - std::span tensorData(tensor.const_data_ptr(), - tensor.numel()); + std::span tensorData(tensor.const_data_ptr(), tensor.numel()); /* The output of the model is a matrix half the size of the input image containing two channels representing the heatmaps. */ auto [scoreTextMat, scoreAffinityMat] = utils::interleavedArrayToMats( - tensorData, - cv::Size(modelInputSize.width / 2, modelInputSize.height / 2)); + tensorData, cv::Size(modelInputSize.width / 2, modelInputSize.height / 2)); /* Heatmaps are then converted into list of bounding boxes. */ - std::vector bBoxesList = utils::getDetBoxesFromTextMap( - scoreTextMat, scoreAffinityMat, constants::kTextThreshold, - constants::kLinkThreshold, constants::kLowTextThreshold); + std::vector bBoxesList = + utils::getDetBoxesFromTextMap(scoreTextMat, scoreAffinityMat, constants::kTextThreshold, + constants::kLinkThreshold, constants::kLowTextThreshold); /* Bounding boxes are at first corresponding to the 400x400 size or 640x640. @@ -102,8 +92,8 @@ Detector::postprocess(const Tensor &tensor, const cv::Size &modelInputSize) { 1280x1280. To match this difference we has to scale by the proper factor (3.2 or 2.0). */ - const float restoreRatio = utils::calculateRestoreRatio( - scoreTextMat.rows, constants::kRecognizerImageSize); + const float restoreRatio = + utils::calculateRestoreRatio(scoreTextMat.rows, constants::kRecognizerImageSize); utils::restoreBboxRatio(bBoxesList, restoreRatio); /* Since every bounding box is processed separately by Recognition models, we'd @@ -111,10 +101,10 @@ Detector::postprocess(const Tensor &tensor, const cv::Size &modelInputSize) { process many words / full line at once. It is not only faster but also easier for Recognizer models than recognition of single characters. */ - bBoxesList = utils::groupTextBoxes( - bBoxesList, constants::kCenterThreshold, constants::kDistanceThreshold, - constants::kHeightThreshold, constants::kMinSideThreshold, - constants::kMaxSideThreshold, constants::kMaxWidth); + bBoxesList = + utils::groupTextBoxes(bBoxesList, constants::kCenterThreshold, constants::kDistanceThreshold, + constants::kHeightThreshold, constants::kMinSideThreshold, + constants::kMaxSideThreshold, constants::kMaxWidth); return bBoxesList; } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/Detector.h b/packages/react-native-executorch/common/rnexecutorch/models/ocr/Detector.h index dc17aa0742..6e53faba2e 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/Detector.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/Detector.h @@ -21,9 +21,8 @@ class Detector : public BaseModel { public: explicit Detector(const std::string &modelSource, std::shared_ptr callInvoker); - [[nodiscard("Registered non-void function")]] - virtual std::vector - generate(const cv::Mat &inputImage, int32_t inputWidth); + [[nodiscard("Registered non-void function")]] + virtual std::vector generate(const cv::Mat &inputImage, int32_t inputWidth); cv::Size calculateModelImageSize(int32_t methodInputWidth); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp index d3e6964a05..45e399a4d9 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp @@ -10,8 +10,7 @@ namespace rnexecutorch::models::ocr { OCR::OCR(const std::string &detectorSource, const std::string &recognizerSource, - const std::string &symbols, - std::shared_ptr callInvoker) + const std::string &symbols, std::shared_ptr callInvoker) : detector(detectorSource, callInvoker), recognitionHandler(recognizerSource, symbols, callInvoker) {} @@ -34,10 +33,9 @@ std::vector OCR::runInference(cv::Mat image) { - coordinates of bounding box corresponding to the original image size - confidence score */ - std::vector result = - recognitionHandler.recognize(bboxesList, image, - cv::Size(constants::kRecognizerImageSize, - constants::kRecognizerImageSize)); + std::vector result = recognitionHandler.recognize( + bboxesList, image, + cv::Size(constants::kRecognizerImageSize, constants::kRecognizerImageSize)); return result; } @@ -51,8 +49,8 @@ std::vector OCR::generateFromString(std::string input) { return runInference(image); } -std::vector -OCR::generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData) { +std::vector OCR::generateFromFrame(jsi::Runtime &runtime, + const jsi::Value &frameData) { auto orient = ::rnexecutorch::utils::readFrameOrientation(runtime, frameData); cv::Mat frame = ::rnexecutorch::utils::frameToMat(runtime, frameData); cv::Mat bgr; @@ -61,34 +59,28 @@ OCR::generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData) { #elif defined(__ANDROID__) cv::cvtColor(frame, bgr, cv::COLOR_RGBA2BGR); #else - throw RnExecutorchError( - RnExecutorchErrorCode::PlatformNotSupported, - "generateFromFrame is not supported on this platform"); + throw RnExecutorchError(RnExecutorchErrorCode::PlatformNotSupported, + "generateFromFrame is not supported on this platform"); #endif cv::Mat rotated = ::rnexecutorch::utils::rotateFrameForModel(bgr, orient); auto detections = runInference(rotated); for (auto &det : detections) { std::array corners = {det.bbox.p1, det.bbox.p2}; ::rnexecutorch::utils::inverseRotatePoints(corners, orient, rotated.size()); - det.bbox = {{std::min(corners[0].x, corners[1].x), - std::min(corners[0].y, corners[1].y)}, - {std::max(corners[0].x, corners[1].x), - std::max(corners[0].y, corners[1].y)}}; + det.bbox = {{std::min(corners[0].x, corners[1].x), std::min(corners[0].y, corners[1].y)}, + {std::max(corners[0].x, corners[1].x), std::max(corners[0].y, corners[1].y)}}; } return detections; } -std::vector -OCR::generateFromPixels(JSTensorViewIn pixelData) { +std::vector OCR::generateFromPixels(JSTensorViewIn pixelData) { cv::Mat image; - cv::cvtColor(::rnexecutorch::utils::pixelsToMat(pixelData), image, - cv::COLOR_RGB2BGR); + cv::cvtColor(::rnexecutorch::utils::pixelsToMat(pixelData), image, cv::COLOR_RGB2BGR); return runInference(image); } std::size_t OCR::getMemoryLowerBound() const noexcept { - return detector.getMemoryLowerBound() + - recognitionHandler.getMemoryLowerBound(); + return detector.getMemoryLowerBound() + recognitionHandler.getMemoryLowerBound(); } void OCR::unload() noexcept { diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.h b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.h index 719cb957c4..aedd77324b 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.h @@ -26,9 +26,8 @@ namespace models::ocr { class OCR final { public: - explicit OCR(const std::string &detectorSource, - const std::string &recognizerSource, const std::string &symbols, - std::shared_ptr callInvoker); + explicit OCR(const std::string &detectorSource, const std::string &recognizerSource, + const std::string &symbols, std::shared_ptr callInvoker); [[nodiscard("Registered non-void function")]] std::vector generateFromString(std::string input); [[nodiscard("Registered non-void function")]] std::vector diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandler.cpp b/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandler.cpp index e70e46ab9c..f8a8820946 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandler.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandler.cpp @@ -4,15 +4,14 @@ #include namespace rnexecutorch::models::ocr { -RecognitionHandler::RecognitionHandler( - const std::string &recognizerSource, const std::string &symbols, - std::shared_ptr callInvoker) +RecognitionHandler::RecognitionHandler(const std::string &recognizerSource, + const std::string &symbols, + std::shared_ptr callInvoker) : converter(symbols), recognizer(recognizerSource, callInvoker) { memorySizeLowerBound = recognizer.getMemoryLowerBound(); } -std::pair, float> -RecognitionHandler::runModel(cv::Mat image) { +std::pair, float> RecognitionHandler::runModel(cv::Mat image) { // Note that the height of an image is always equal to 64. int32_t desiredWidth = utils::getDesiredWidth(image, false); @@ -27,8 +26,7 @@ void RecognitionHandler::processBBox(std::vector &boxList, Resize the cropped image to have height = 64 (height accepted by Recognizer). */ - auto croppedImage = - utils::cropImage(box, imgGray, constants::kRecognizerHeight); + auto croppedImage = utils::cropImage(box, imgGray, constants::kRecognizerHeight); if (croppedImage.empty()) { return; @@ -38,15 +36,13 @@ void RecognitionHandler::processBBox(std::vector &boxList, Cropped image is resized into the closest of on of three: 128x64, 256x64, 512x64. */ - croppedImage = - utils::normalizeForRecognizer(croppedImage, constants::kRecognizerHeight, - constants::kAdjustContrast, false); + croppedImage = utils::normalizeForRecognizer(croppedImage, constants::kRecognizerHeight, + constants::kAdjustContrast, false); auto [predictionIndices, confidenceScore] = this->runModel(croppedImage); if (confidenceScore < constants::kLowConfidenceThreshold) { cv::rotate(croppedImage, croppedImage, cv::ROTATE_180); - auto [rotatedPredictionIndices, rotatedConfidenceScore] = - runModel(croppedImage); + auto [rotatedPredictionIndices, rotatedConfidenceScore] = runModel(croppedImage); if (rotatedConfidenceScore > confidenceScore) { confidenceScore = rotatedConfidenceScore; predictionIndices = rotatedPredictionIndices; @@ -64,15 +60,14 @@ void RecognitionHandler::processBBox(std::vector &boxList, types::BBox transformedBbox{ {(box.bbox.p1.x - padLeft) * ratio, (box.bbox.p1.y - padTop) * ratio}, {(box.bbox.p2.x - padLeft) * ratio, (box.bbox.p2.y - padTop) * ratio}}; - boxList.emplace_back( - transformedBbox, - converter.decodeGreedy(predictionIndices, predictionIndices.size())[0], - confidenceScore); + boxList.emplace_back(transformedBbox, + converter.decodeGreedy(predictionIndices, predictionIndices.size())[0], + confidenceScore); } std::vector -RecognitionHandler::recognize(std::vector bboxesList, - cv::Mat &imgGray, cv::Size desiredSize) { +RecognitionHandler::recognize(std::vector bboxesList, cv::Mat &imgGray, + cv::Size desiredSize) { /* Recognition Handler accepts bboxesList corresponding to size 1280x1280, which is desiredSize. diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandler.h b/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandler.h index abdfe5ba97..b297324395 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandler.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandler.h @@ -17,20 +17,17 @@ namespace rnexecutorch::models::ocr { class RecognitionHandler final { public: - explicit RecognitionHandler(const std::string &recognizer, - const std::string &symbols, + explicit RecognitionHandler(const std::string &recognizer, const std::string &symbols, std::shared_ptr callInvoker); - std::vector - recognize(std::vector bboxesList, cv::Mat &imgGray, - cv::Size desiredSize); + std::vector recognize(std::vector bboxesList, + cv::Mat &imgGray, cv::Size desiredSize); void unload() noexcept; std::size_t getMemoryLowerBound() const noexcept; private: std::pair, float> runModel(cv::Mat image); - void processBBox(std::vector &boxList, - types::DetectorBBox &box, cv::Mat &imgGray, - types::PaddingInfo ratioAndPadding); + void processBBox(std::vector &boxList, types::DetectorBBox &box, + cv::Mat &imgGray, types::PaddingInfo ratioAndPadding); std::size_t memorySizeLowerBound{0}; CTCLabelConverter converter; Recognizer recognizer; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/Recognizer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/ocr/Recognizer.cpp index b87bcb1a9e..b6a28085b6 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/Recognizer.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/Recognizer.cpp @@ -15,8 +15,8 @@ Recognizer::Recognizer(const std::string &modelSource, std::shared_ptr callInvoker) : BaseModel(modelSource, callInvoker) {} -std::pair, float> -Recognizer::generate(const cv::Mat &grayImage, int32_t inputWidth) { +std::pair, float> Recognizer::generate(const cv::Mat &grayImage, + int32_t inputWidth) { /* In our pipeline we use three types of Recognizer, each designated to handle different image sizes: @@ -26,8 +26,7 @@ Recognizer::generate(const cv::Mat &grayImage, int32_t inputWidth) { The `generate` function as an argument accepts an image in grayscale already resized to the expected size. */ - utils::validateInputWidth(inputWidth, constants::kRecognizerInputWidths, - "Recognizer"); + utils::validateInputWidth(inputWidth, constants::kRecognizerInputWidths, "Recognizer"); std::string method_name = "forward_" + std::to_string(inputWidth); auto shapes = getAllInputShapes(method_name); @@ -36,16 +35,14 @@ Recognizer::generate(const cv::Mat &grayImage, int32_t inputWidth) { "OCR method takes no inputs: " + method_name); } std::vector tensorDims = shapes[0]; - TensorPtr inputTensor = - image_processing::getTensorFromMatrixGray(tensorDims, grayImage); + TensorPtr inputTensor = image_processing::getTensorFromMatrixGray(tensorDims, grayImage); auto forwardResult = BaseModel::execute(method_name, {inputTensor}); CHECK_OK_OR_THROW_FORWARD_ERROR(forwardResult); return postprocess(forwardResult->at(0).toTensor()); } -std::pair, float> -Recognizer::postprocess(const Tensor &tensor) const { +std::pair, float> Recognizer::postprocess(const Tensor &tensor) const { /* Raw model returns a tensor with dimensions [ 1 x seqLen x alphabetSize ] where: @@ -67,8 +64,7 @@ Recognizer::postprocess(const Tensor &tensor) const { const int32_t alphabetSize = tensor.size(2); const int32_t numRows = tensor.numel() / alphabetSize; - cv::Mat resultMat(numRows, alphabetSize, CV_32F, - tensor.mutable_data_ptr()); + cv::Mat resultMat(numRows, alphabetSize, CV_32F, tensor.mutable_data_ptr()); auto probabilities = utils::softmax(resultMat); auto [maxVal, maxIndices] = utils::findMaxValuesIndices(probabilities); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/Recognizer.h b/packages/react-native-executorch/common/rnexecutorch/models/ocr/Recognizer.h index 9a8129eb63..1a3102ff0c 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/Recognizer.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/Recognizer.h @@ -25,12 +25,10 @@ class Recognizer final : public BaseModel { public: explicit Recognizer(const std::string &modelSource, std::shared_ptr callInvoker); - [[nodiscard("Registered non-void function")]] - std::pair, float> generate(const cv::Mat &grayImage, - int32_t inputWidth); + [[nodiscard("Registered non-void function")]] + std::pair, float> generate(const cv::Mat &grayImage, int32_t inputWidth); private: - std::pair, float> - postprocess(const Tensor &tensor) const; + std::pair, float> postprocess(const Tensor &tensor) const; }; } // namespace rnexecutorch::models::ocr diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/DetectorUtils.cpp b/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/DetectorUtils.cpp index d1d4d1b5a0..19757768fd 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/DetectorUtils.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/DetectorUtils.cpp @@ -12,8 +12,7 @@ static std::array bboxToCorners(const types::BBox &bbox) { return {bbox.p1, {bbox.p2.x, bbox.p1.y}, bbox.p2, {bbox.p1.x, bbox.p2.y}}; } -std::pair interleavedArrayToMats(std::span data, - cv::Size size) { +std::pair interleavedArrayToMats(std::span data, cv::Size size) { cv::Mat mat1 = cv::Mat(size, CV_32F); cv::Mat mat2 = cv::Mat(size, CV_32F); @@ -33,20 +32,18 @@ std::pair interleavedArrayToMats(std::span data, // Create a segmentation map for the current component. // Background is 0, (black), foreground is 255 (white) -cv::Mat createSegmentMap(const cv::Mat &mask, cv::Size mapSize, - const int32_t segmentColor = 255) { +cv::Mat createSegmentMap(const cv::Mat &mask, cv::Size mapSize, const int32_t segmentColor = 255) { cv::Mat segMap = cv::Mat::zeros(mapSize, CV_8U); segMap.setTo(segmentColor, mask); return segMap; } void morphologicalOperations( - const cv::Mat &segMap, const cv::Mat &stats, int32_t i, int32_t area, - int32_t imgW, int32_t imgH, - int32_t iterations = 1, // iterations number of times dilation is applied. - cv::Size anchor = - cv::Point(-1, -1) // anchor position of the anchor within the element; - // default means that the anchor is at the center. + const cv::Mat &segMap, const cv::Mat &stats, int32_t i, int32_t area, int32_t imgW, + int32_t imgH, + int32_t iterations = 1, // iterations number of times dilation is applied. + cv::Size anchor = cv::Point(-1, -1) // anchor position of the anchor within the element; + // default means that the anchor is at the center. ) { const int32_t x = stats.at(i, cv::CC_STAT_LEFT); const int32_t y = stats.at(i, cv::CC_STAT_TOP); @@ -56,8 +53,7 @@ void morphologicalOperations( // Dynamically calculate dilation radius to expand the bounding box slightly constexpr int32_t evenMultiplyCoeff = 2; // ensure that dilationRadius is even const int32_t dilationRadius = static_cast( - std::sqrt(static_cast(area) / std::max(w, h)) * - evenMultiplyCoeff); + std::sqrt(static_cast(area) / std::max(w, h)) * evenMultiplyCoeff); const int32_t sx = std::max(x - dilationRadius, 0); const int32_t ex = std::min(x + w + dilationRadius, imgW); const int32_t sy = std::max(y - dilationRadius, 0); @@ -71,36 +67,28 @@ void morphologicalOperations( 1 + dilationRadius; // Ensures valid odd-sized kernel, // notice the fact that dilationRadius is always even. cv::Mat kernel = cv::getStructuringElement( - cv::MORPH_RECT, - cv::Size(morphologicalKernelSize, morphologicalKernelSize)); + cv::MORPH_RECT, cv::Size(morphologicalKernelSize, morphologicalKernelSize)); cv::Mat roiSegMap = segMap(roi); cv::dilate(roiSegMap, roiSegMap, kernel, anchor, iterations); } -types::DetectorBBox -extractMinAreaBBoxFromContour(const std::vector contour) { +types::DetectorBBox extractMinAreaBBoxFromContour(const std::vector contour) { cv::RotatedRect minRect = cv::minAreaRect(contour); std::array vertices; minRect.points(vertices.data()); - float minX = - std::min({vertices[0].x, vertices[1].x, vertices[2].x, vertices[3].x}); - float minY = - std::min({vertices[0].y, vertices[1].y, vertices[2].y, vertices[3].y}); - float maxX = - std::max({vertices[0].x, vertices[1].x, vertices[2].x, vertices[3].x}); - float maxY = - std::max({vertices[0].y, vertices[1].y, vertices[2].y, vertices[3].y}); + float minX = std::min({vertices[0].x, vertices[1].x, vertices[2].x, vertices[3].x}); + float minY = std::min({vertices[0].y, vertices[1].y, vertices[2].y, vertices[3].y}); + float maxX = std::max({vertices[0].x, vertices[1].x, vertices[2].x, vertices[3].x}); + float maxY = std::max({vertices[0].y, vertices[1].y, vertices[2].y, vertices[3].y}); types::BBox bbox = {{minX, minY}, {maxX, maxY}}; return {.bbox = bbox, .angle = minRect.angle}; } -void getBoxFromContour(cv::Mat &segMap, - std::vector &detectedBoxes) { +void getBoxFromContour(cv::Mat &segMap, std::vector &detectedBoxes) { std::vector> contours; - cv::findContours(segMap, contours, cv::RETR_EXTERNAL, - cv::CHAIN_APPROX_SIMPLE); + cv::findContours(segMap, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE); if (!contours.empty()) { detectedBoxes.emplace_back(extractMinAreaBBoxFromContour(contours[0])); } @@ -109,12 +97,11 @@ void getBoxFromContour(cv::Mat &segMap, // Function for processing single component. It is shared between the // VerticalOCR and standard OCR. param isVertical specifies which OCR uses it. // param lowTextThreshold is used only by standard OCR. -void processComponent(const cv::Mat &textMap, const cv::Mat &labels, - const cv::Mat &stats, int32_t i, int32_t imgW, - int32_t imgH, - std::vector &detectedBoxes, - bool isVertical, int32_t minimalAreaThreshold, - int32_t dilationIter, float lowTextThreshold = 0.0) { +void processComponent(const cv::Mat &textMap, const cv::Mat &labels, const cv::Mat &stats, + int32_t i, int32_t imgW, int32_t imgH, + std::vector &detectedBoxes, bool isVertical, + int32_t minimalAreaThreshold, int32_t dilationIter, + float lowTextThreshold = 0.0) { const int32_t area = stats.at(i, cv::CC_STAT_AREA); // Skip small components as they are likely to be just noise if (area < minimalAreaThreshold) { @@ -144,10 +131,9 @@ void processComponent(const cv::Mat &textMap, const cv::Mat &labels, getBoxFromContour(segMap, detectedBoxes); } -std::vector -getDetBoxesFromTextMap(cv::Mat &textMap, cv::Mat &affinityMap, - float textThreshold, float linkThreshold, - float lowTextThreshold) { +std::vector getDetBoxesFromTextMap(cv::Mat &textMap, cv::Mat &affinityMap, + float textThreshold, float linkThreshold, + float lowTextThreshold) { // Ensure input mats are of the correct type for processing CV_Assert(textMap.type() == CV_32F && affinityMap.type() == CV_32F); @@ -158,24 +144,21 @@ getDetBoxesFromTextMap(cv::Mat &textMap, cv::Mat &affinityMap, // 1. Based on maps and threshold values create binary masks constexpr double maxValBinaryMask = 1.0; - cv::threshold(textMap, textScore, textThreshold, maxValBinaryMask, - cv::THRESH_BINARY); - cv::threshold(affinityMap, affinityScore, linkThreshold, maxValBinaryMask, - cv::THRESH_BINARY); + cv::threshold(textMap, textScore, textThreshold, maxValBinaryMask, cv::THRESH_BINARY); + cv::threshold(affinityMap, affinityScore, linkThreshold, maxValBinaryMask, cv::THRESH_BINARY); // 2. Merge two maps into one using logical OR cv::Mat textScoreComb = textScore + affinityScore; constexpr double threshVal = 0.0; - cv::threshold(textScoreComb, textScoreComb, threshVal, maxValBinaryMask, - cv::THRESH_BINARY); + cv::threshold(textScoreComb, textScoreComb, threshVal, maxValBinaryMask, cv::THRESH_BINARY); cv::Mat binaryMat; textScoreComb.convertTo(binaryMat, CV_8UC1); // 3. Find connected components to identify each box cv::Mat labels, stats, centroids; constexpr int32_t connectivityType = 4; - const int32_t nLabels = cv::connectedComponentsWithStats( - binaryMat, labels, stats, centroids, connectivityType); + const int32_t nLabels = + cv::connectedComponentsWithStats(binaryMat, labels, stats, centroids, connectivityType); std::vector detectedBoxes; detectedBoxes.reserve(nLabels); // Pre-allocate memory @@ -188,18 +171,16 @@ getDetBoxesFromTextMap(cv::Mat &textMap, cv::Mat &affinityMap, // 4. Process each component; omit component 0 as it is background for (int32_t i = 1; i < nLabels; i++) { - processComponent(textMap, labels, stats, i, imgW, imgH, detectedBoxes, - false, minimalAreaThreshold, dilationIter, - lowTextThreshold); + processComponent(textMap, labels, stats, i, imgW, imgH, detectedBoxes, false, + minimalAreaThreshold, dilationIter, lowTextThreshold); } return detectedBoxes; } std::vector -getDetBoxesFromTextMapVertical(cv::Mat &textMap, cv::Mat &affinityMap, - float textThreshold, float linkThreshold, - bool independentCharacters) { +getDetBoxesFromTextMapVertical(cv::Mat &textMap, cv::Mat &affinityMap, float textThreshold, + float linkThreshold, bool independentCharacters) { // Ensure input mats are of the correct type for processing CV_Assert(textMap.type() == CV_32F && affinityMap.type() == CV_32F); @@ -210,10 +191,8 @@ getDetBoxesFromTextMapVertical(cv::Mat &textMap, cv::Mat &affinityMap, // 1. Threshold text and affinity maps to create binary masks constexpr double maxValBinaryMask = 1.0; - cv::threshold(textMap, textScore, textThreshold, maxValBinaryMask, - cv::THRESH_BINARY); - cv::threshold(affinityMap, affinityScore, linkThreshold, maxValBinaryMask, - cv::THRESH_BINARY); + cv::threshold(textMap, textScore, textThreshold, maxValBinaryMask, cv::THRESH_BINARY); + cv::threshold(affinityMap, affinityScore, linkThreshold, maxValBinaryMask, cv::THRESH_BINARY); // Prepare values for morphological operations const auto kSize = cv::Size(3, 3); // size of the structuring element @@ -224,17 +203,16 @@ getDetBoxesFromTextMapVertical(cv::Mat &textMap, cv::Mat &affinityMap, // iterations number of times dilation is applied. int32_t dilationIterations; - const auto anchor = - cv::Point(-1, -1); // anchor position of the anchor within the element; - // default value (-1, -1) - // means that the anchor is at the element center + const auto anchor = cv::Point(-1, -1); // anchor position of the anchor within the element; + // default value (-1, -1) + // means that the anchor is at the element center // 2. Combine maps based on whether we are detecting words or single // characters // For single characters, subtract affinity to separate adjacent chars, // otherwise add affinity to link characters together - cv::Mat textScoreComb = independentCharacters ? textScore - affinityScore - : textScore + affinityScore; + cv::Mat textScoreComb = + independentCharacters ? textScore - affinityScore : textScore + affinityScore; // Clamp values to be >= 0 cv::threshold(textScoreComb, textScoreComb, 0.0, 1.0, cv::THRESH_TOZERO); // Clamp values to be <= 1 @@ -255,8 +233,8 @@ getDetBoxesFromTextMapVertical(cv::Mat &textMap, cv::Mat &affinityMap, cv::Mat labels, stats, centroids; constexpr int32_t connectivityType = 4; - const int32_t nLabels = cv::connectedComponentsWithStats( - binaryMat, labels, stats, centroids, connectivityType); + const int32_t nLabels = + cv::connectedComponentsWithStats(binaryMat, labels, stats, centroids, connectivityType); std::vector detectedBoxes; detectedBoxes.reserve(nLabels); @@ -286,8 +264,7 @@ float calculateRestoreRatio(int32_t currentSize, int32_t desiredSize) { return desiredSize / static_cast(currentSize); } -void restoreBboxRatio(std::vector &boxes, - float restoreRatio) { +void restoreBboxRatio(std::vector &boxes, float restoreRatio) { for (auto &box : boxes) { box.bbox = box.bbox.scale(restoreRatio, restoreRatio); } @@ -299,26 +276,17 @@ float distanceFromPoint(const types::Point &p1, const types::Point &p2) { return std::hypot(xDist, yDist); } -float normalizeAngle(float angle) { - return (angle > 45.0f) ? (angle - 90.0f) : angle; -} +float normalizeAngle(float angle) { return (angle > 45.0f) ? (angle - 90.0f) : angle; } -types::Point midpointBetweenPoint(const types::Point &p1, - const types::Point &p2) { +types::Point midpointBetweenPoint(const types::Point &p1, const types::Point &p2) { return {.x = std::midpoint(p1.x, p2.x), .y = std::midpoint(p1.y, p2.y)}; } -types::Point centerOfBox(const types::BBox &box) { - return midpointBetweenPoint(box.p1, box.p2); -} +types::Point centerOfBox(const types::BBox &box) { return midpointBetweenPoint(box.p1, box.p2); } -float minSideLength(const types::BBox &bbox) { - return std::min(bbox.width(), bbox.height()); -} +float minSideLength(const types::BBox &bbox) { return std::min(bbox.width(), bbox.height()); } -float maxSideLength(const types::BBox &bbox) { - return std::max(bbox.width(), bbox.height()); -} +float maxSideLength(const types::BBox &bbox) { return std::max(bbox.width(), bbox.height()); } /** * This method calculates the distances between each sequential pair of points @@ -360,16 +328,13 @@ std::tuple fitLineToShortestSides(const types::BBox &bbox) { float m, c; bool isVertical; - std::array cvMidPoints = { - cv::Point2f(midpoint1.x, midpoint1.y), - cv::Point2f(midpoint2.x, midpoint2.y)}; + std::array cvMidPoints = {cv::Point2f(midpoint1.x, midpoint1.y), + cv::Point2f(midpoint2.x, midpoint2.y)}; cv::Vec4f line; // parameteres for fitLine calculation: - constexpr int32_t numericalParameter = - 0; // important only for some types of distances, O means an optimal value - // is chosen - constexpr double accuracy = - 0.01; // sufficient accuracy. Value proposed by OPENCV + constexpr int32_t numericalParameter = 0; // important only for some types of distances, O means + // an optimal value is chosen + constexpr double accuracy = 0.01; // sufficient accuracy. Value proposed by OPENCV isVertical = dx < constants::kVerticalLineThreshold; if (isVertical) { @@ -377,8 +342,7 @@ std::tuple fitLineToShortestSides(const types::BBox &bbox) { std::swap(pt.x, pt.y); } } - cv::fitLine(cvMidPoints, line, cv::DIST_L2, numericalParameter, accuracy, - accuracy); + cv::fitLine(cvMidPoints, line, cv::DIST_L2, numericalParameter, accuracy, accuracy); m = line[1] / line[0]; c = line[3] - m * line[2]; return {m, c, isVertical}; @@ -407,8 +371,7 @@ types::BBox rotateBox(const types::BBox &bbox, float angle) { return {{minX, minY}, {maxX, maxY}}; } -float calculateMinimalDistanceBetweenBox(const types::BBox &box1, - const types::BBox &box2) { +float calculateMinimalDistanceBetweenBox(const types::BBox &box1, const types::BBox &box2) { float minDistance = std::numeric_limits::max(); for (const auto &c1 : bboxToCorners(box1)) { for (const auto &c2 : bboxToCorners(box2)) { @@ -439,8 +402,7 @@ types::BBox orderPointsClockwise(const types::BBox &bbox) { {std::max(bbox.p1.x, bbox.p2.x), std::max(bbox.p1.y, bbox.p2.y)}}; } -types::BBox mergeRotatedBoxes(const types::BBox &box1, - const types::BBox &box2) { +types::BBox mergeRotatedBoxes(const types::BBox &box1, const types::BBox &box2) { return {{std::min(box1.p1.x, box2.p1.x), std::min(box1.p1.y, box2.p1.y)}, {std::max(box1.p2.x, box2.p2.x), std::max(box1.p2.y, box2.p2.y)}}; } @@ -471,9 +433,8 @@ types::BBox mergeRotatedBoxes(const types::BBox &box1, */ std::optional> findClosestBox(const std::vector &boxes, - const std::unordered_set &ignoredIdxs, - const types::BBox ¤tBox, bool isVertical, float m, float c, - float centerThreshold) { + const std::unordered_set &ignoredIdxs, const types::BBox ¤tBox, + bool isVertical, float m, float c, float centerThreshold) { float smallestDistance = std::numeric_limits::max(); ssize_t idx = -1; float boxHeight = 0.0f; @@ -495,10 +456,8 @@ findClosestBox(const std::vector &boxes, boxHeight = minSideLength(bbox); const float lineDistance = - isVertical ? std::fabs(centerOfProcessedBox.x - - (m * centerOfProcessedBox.y + c)) - : std::fabs(centerOfProcessedBox.y - - (m * centerOfProcessedBox.x + c)); + isVertical ? std::fabs(centerOfProcessedBox.x - (m * centerOfProcessedBox.y + c)) + : std::fabs(centerOfProcessedBox.y - (m * centerOfProcessedBox.x + c)); if (lineDistance < boxHeight * centerThreshold) { idx = i; @@ -506,8 +465,7 @@ findClosestBox(const std::vector &boxes, } } - return idx != -1 ? std::optional(std::make_pair(idx, boxHeight)) - : std::nullopt; + return idx != -1 ? std::optional(std::make_pair(idx, boxHeight)) : std::nullopt; } /** @@ -518,8 +476,8 @@ findClosestBox(const std::vector &boxes, * Otherwise, the box is excluded from the result. */ std::vector -removeSmallBoxesFromArray(const std::vector &boxes, - float minSideThreshold, float maxSideThreshold) { +removeSmallBoxesFromArray(const std::vector &boxes, float minSideThreshold, + float maxSideThreshold) { std::vector filteredBoxes; for (const auto &box : boxes) { @@ -536,14 +494,12 @@ removeSmallBoxesFromArray(const std::vector &boxes, static float minimumYFromBox(const types::BBox &bbox) { return bbox.p1.y; } static float minimumXFromBox(const types::BBox &bbox) { return bbox.p1.x; } -std::vector -groupTextBoxes(std::vector &boxes, float centerThreshold, - float distanceThreshold, float heightThreshold, - int32_t minSideThreshold, int32_t maxSideThreshold, - int32_t maxWidth) { +std::vector groupTextBoxes(std::vector &boxes, + float centerThreshold, float distanceThreshold, + float heightThreshold, int32_t minSideThreshold, + int32_t maxSideThreshold, int32_t maxWidth) { // Sort boxes descending by maximum side length - std::ranges::sort(boxes, [](const types::DetectorBBox &lhs, - const types::DetectorBBox &rhs) { + std::ranges::sort(boxes, [](const types::DetectorBBox &lhs, const types::DetectorBBox &rhs) { return maxSideLength(lhs.bbox) > maxSideLength(rhs.bbox); }); @@ -559,16 +515,14 @@ groupTextBoxes(std::vector &boxes, float centerThreshold, while (true) { // Find all aligned boxes and merge them until max_size is reached or no // more boxes can be merged - auto [slope, intercept, isVertical] = - fitLineToShortestSides(currentBox.bbox); + auto [slope, intercept, isVertical] = fitLineToShortestSides(currentBox.bbox); lineAngle = std::atan(slope) * 180.0f / M_PI; if (isVertical) { lineAngle = -90.0f; } - auto closestBoxInfo = - findClosestBox(boxes, ignoredIdxs, currentBox.bbox, isVertical, slope, - intercept, centerThreshold); + auto closestBoxInfo = findClosestBox(boxes, ignoredIdxs, currentBox.bbox, isVertical, slope, + intercept, centerThreshold); if (!closestBoxInfo.has_value()) { break; } @@ -580,12 +534,11 @@ groupTextBoxes(std::vector &boxes, float centerThreshold, candidateBox.bbox = rotateBox(candidateBox.bbox, normalizedAngle); } - const float minDistance = calculateMinimalDistanceBetweenBox( - candidateBox.bbox, currentBox.bbox); + const float minDistance = + calculateMinimalDistanceBetweenBox(candidateBox.bbox, currentBox.bbox); const float mergedHeight = minSideLength(currentBox.bbox); if (minDistance < distanceThreshold * candidateHeight && - std::fabs(mergedHeight - candidateHeight) < - candidateHeight * heightThreshold) { + std::fabs(mergedHeight - candidateHeight) < candidateHeight * heightThreshold) { currentBox.bbox = mergeRotatedBoxes(currentBox.bbox, candidateBox.bbox); boxes.erase(boxes.begin() + candidateIdx); ignoredIdxs.clear(); @@ -600,8 +553,7 @@ groupTextBoxes(std::vector &boxes, float centerThreshold, } // Remove small boxes and sort by vertical - mergedVec = - removeSmallBoxesFromArray(mergedVec, minSideThreshold, maxSideThreshold); + mergedVec = removeSmallBoxesFromArray(mergedVec, minSideThreshold, maxSideThreshold); std::ranges::sort(mergedVec, [](const auto &obj1, const auto &obj2) { return minimumYFromBox(obj1.bbox) < minimumYFromBox(obj2.bbox); @@ -621,10 +573,9 @@ groupTextBoxes(std::vector &boxes, float centerThreshold, } for (auto rowBegin = mergedVec.begin(); rowBegin != mergedVec.end();) { const float rowY = minimumYFromBox(rowBegin->bbox); - auto rowEnd = - std::find_if(rowBegin, mergedVec.end(), [rowY, yThresh](const auto &b) { - return minimumYFromBox(b.bbox) - rowY > yThresh; - }); + auto rowEnd = std::find_if(rowBegin, mergedVec.end(), [rowY, yThresh](const auto &b) { + return minimumYFromBox(b.bbox) - rowY > yThresh; + }); std::sort(rowBegin, rowEnd, [](const auto &a, const auto &b) { return minimumXFromBox(a.bbox) < minimumXFromBox(b.bbox); }); @@ -648,13 +599,11 @@ void validateInputWidth(int32_t inputWidth, std::span constants, if (it == constants.end()) { std::string allowed; for (size_t i = 0; i < constants.size(); ++i) { - allowed += - std::to_string(constants[i]) + (i < constants.size() - 1 ? ", " : ""); + allowed += std::to_string(constants[i]) + (i < constants.size() - 1 ? ", " : ""); } - throw std::runtime_error("Unexpected input width for " + modelName + - "! Expected [" + allowed + "] but got " + - std::to_string(inputWidth) + "."); + throw std::runtime_error("Unexpected input width for " + modelName + "! Expected [" + allowed + + "] but got " + std::to_string(inputWidth) + "."); } } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/DetectorUtils.h b/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/DetectorUtils.h index 0b742a4ce1..535d6d392d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/DetectorUtils.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/DetectorUtils.h @@ -5,8 +5,7 @@ #include namespace rnexecutorch::models::ocr::utils { -std::pair interleavedArrayToMats(std::span data, - cv::Size size); +std::pair interleavedArrayToMats(std::span data, cv::Size size); /** * This method applies a series of image processing operations to identify * likely areas of text in the textMap and return the bounding boxes for single @@ -27,20 +26,16 @@ std::pair interleavedArrayToMats(std::span data, * detected text box. * - "angle": a float representing the rotation angle of the box. */ -std::vector getDetBoxesFromTextMap(cv::Mat &textMap, - cv::Mat &affinityMap, - float textThreshold, - float linkThreshold, +std::vector getDetBoxesFromTextMap(cv::Mat &textMap, cv::Mat &affinityMap, + float textThreshold, float linkThreshold, float lowTextThreshold); std::vector -getDetBoxesFromTextMapVertical(cv::Mat &textMap, cv::Mat &affinityMap, - float textThreshold, float linkThreshold, - bool independentCharacters); +getDetBoxesFromTextMapVertical(cv::Mat &textMap, cv::Mat &affinityMap, float textThreshold, + float linkThreshold, bool independentCharacters); float calculateRestoreRatio(int32_t currentSize, int32_t desiredSize); -void restoreBboxRatio(std::vector &boxes, - float restoreRatio); +void restoreBboxRatio(std::vector &boxes, float restoreRatio); /** * This method processes a vector of DetectorBBox bounding boxes, each * containing details about individual text boxes, and attempts to group and @@ -73,11 +68,10 @@ void restoreBboxRatio(std::vector &boxes, * 3. Post-processing to remove any boxes that are too small. * 4. Sort the final array of boxes by their vertical positions. */ -std::vector -groupTextBoxes(std::vector &boxes, float centerThreshold, - float distanceThreshold, float heightThreshold, - int32_t minSideThreshold, int32_t maxSideThreshold, - int32_t maxWidth); +std::vector groupTextBoxes(std::vector &boxes, + float centerThreshold, float distanceThreshold, + float heightThreshold, int32_t minSideThreshold, + int32_t maxSideThreshold, int32_t maxWidth); /** * Validates if the provided image width is supported by the model. diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/RecognitionHandlerUtils.cpp b/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/RecognitionHandlerUtils.cpp index 0e50f1c038..836e613125 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/RecognitionHandlerUtils.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/RecognitionHandlerUtils.cpp @@ -4,8 +4,7 @@ #include namespace rnexecutorch::models::ocr::utils { -types::PaddingInfo calculateResizeRatioAndPaddings(cv::Size size, - cv::Size desiredSize) { +types::PaddingInfo calculateResizeRatioAndPaddings(cv::Size size, cv::Size desiredSize) { const auto newRatioH = static_cast(desiredSize.height) / size.height; const auto newRatioW = static_cast(desiredSize.width) / size.width; auto resizeRatio = std::min(newRatioH, newRatioW); @@ -27,21 +26,17 @@ types::PaddingInfo calculateResizeRatioAndPaddings(cv::Size size, } void computeRatioAndResize(cv::Mat &img, cv::Size size, int32_t modelHeight) { - auto ratio = - static_cast(size.width) / static_cast(size.height); + auto ratio = static_cast(size.width) / static_cast(size.height); cv::Size resizedSize; if (ratio < 1.0) { - resizedSize = - cv::Size(modelHeight, static_cast(modelHeight / ratio)); + resizedSize = cv::Size(modelHeight, static_cast(modelHeight / ratio)); } else { - resizedSize = - cv::Size(static_cast(modelHeight * ratio), modelHeight); + resizedSize = cv::Size(static_cast(modelHeight * ratio), modelHeight); } cv::resize(img, img, resizedSize, 0.0, 0.0, cv::INTER_LANCZOS4); } -cv::Mat cropImage(types::DetectorBBox box, cv::Mat &image, - int32_t modelHeight) { +cv::Mat cropImage(types::DetectorBBox box, cv::Mat &image, int32_t modelHeight) { const std::array points = {{ {box.bbox.p1.x, box.bbox.p1.y}, {box.bbox.p2.x, box.bbox.p1.y}, @@ -57,8 +52,7 @@ cv::Mat cropImage(types::DetectorBBox box, cv::Mat &image, cv::Point2f imageCenter(image.cols / 2.0f, image.rows / 2.0f); cv::Mat rotationMatrix = cv::getRotationMatrix2D(imageCenter, box.angle, 1.0); cv::Mat rotatedImage; - cv::warpAffine(image, rotatedImage, rotationMatrix, image.size(), - cv::INTER_LINEAR); + cv::warpAffine(image, rotatedImage, rotationMatrix, image.size(), cv::INTER_LINEAR); constexpr int32_t rows = 4; constexpr int32_t cols = 2; @@ -89,9 +83,7 @@ cv::Mat cropImage(types::DetectorBBox box, cv::Mat &image, cv::Mat croppedImage = rotatedImage(boundingBox).clone(); - computeRatioAndResize(croppedImage, - cv::Size(boundingBox.width, boundingBox.height), - modelHeight); + computeRatioAndResize(croppedImage, cv::Size(boundingBox.width, boundingBox.height), modelHeight); return croppedImage; } @@ -136,12 +128,11 @@ int32_t getDesiredWidth(const cv::Mat &img, bool isVertical) { if (img.cols >= constants::kMediumRecognizerWidth) { return constants::kMediumRecognizerWidth; } - return isVertical ? constants::kSmallVerticalRecognizerWidth - : constants::kSmallRecognizerWidth; + return isVertical ? constants::kSmallVerticalRecognizerWidth : constants::kSmallRecognizerWidth; } -cv::Mat normalizeForRecognizer(const cv::Mat &image, int32_t modelHeight, - double adjustContrast, bool isVertical) { +cv::Mat normalizeForRecognizer(const cv::Mat &image, int32_t modelHeight, double adjustContrast, + bool isVertical) { auto img = image.clone(); if (adjustContrast > 0.0) { adjustContrastGrey(img, adjustContrast); @@ -149,8 +140,7 @@ cv::Mat normalizeForRecognizer(const cv::Mat &image, int32_t modelHeight, int32_t desiredWidth = getDesiredWidth(image, isVertical); - img = - image_processing::resizePadded(img, cv::Size(desiredWidth, modelHeight)); + img = image_processing::resizePadded(img, cv::Size(desiredWidth, modelHeight)); img.convertTo(img, CV_32F, 1.0f / 255.0f); img -= 0.5f; img *= 2.0f; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/RecognitionHandlerUtils.h b/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/RecognitionHandlerUtils.h index e2dea2f7f9..3c2da2bdc0 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/RecognitionHandlerUtils.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/RecognitionHandlerUtils.h @@ -12,8 +12,7 @@ namespace rnexecutorch::models::ocr::utils { * @return Struct containing the scaling factor and top/left padding amounts for * centering the image. */ -types::PaddingInfo calculateResizeRatioAndPaddings(cv::Size size, - cv::Size desiredSize); +types::PaddingInfo calculateResizeRatioAndPaddings(cv::Size size, cv::Size desiredSize); /** * @brief Resizes an image proportionally to match a target height while * maintaining aspect ratio. @@ -85,6 +84,5 @@ int32_t getDesiredWidth(const cv::Mat &img, bool isVertical); * - Normalized float32 values in [-1, 1] range */ cv::Mat normalizeForRecognizer(const cv::Mat &image, int32_t modelHeight, - double adjustContrast = 0.0, - bool isVertical = false); + double adjustContrast = 0.0, bool isVertical = false); } // namespace rnexecutorch::models::ocr::utils diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/RecognizerUtils.cpp b/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/RecognizerUtils.cpp index e959739ab1..a6295b0eb8 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/RecognizerUtils.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/RecognizerUtils.cpp @@ -62,8 +62,7 @@ types::ValuesAndIndices findMaxValuesIndices(const cv::Mat &mat) { return result; } -float confidenceScore(const std::vector &values, - const std::vector &indices) { +float confidenceScore(const std::vector &values, const std::vector &indices) { float product = 1.0f; int32_t count = 0; @@ -85,8 +84,7 @@ float confidenceScore(const std::vector &values, cv::Rect extractBoundingBox(const BBox &bbox) { return cv::Rect(static_cast(bbox.p1.x), static_cast(bbox.p1.y), - static_cast(bbox.width()), - static_cast(bbox.height())); + static_cast(bbox.width()), static_cast(bbox.height())); } cv::Mat characterBitMask(const cv::Mat &img) { @@ -98,8 +96,7 @@ cv::Mat characterBitMask(const cv::Mat &img) { bool uniform = true; bool accumulate = false; - cv::calcHist(&img, 1, 0, cv::Mat(), histogram, 1, &histSize, &histRange, - uniform, accumulate); + cv::calcHist(&img, 1, 0, cv::Mat(), histogram, 1, &histSize, &histRange, uniform, accumulate); // Compare sum of darker (left half) vs brighter (right half) pixels. const int32_t midPoint = histSize / 2; @@ -111,8 +108,7 @@ cv::Mat characterBitMask(const cv::Mat &img) { for (int32_t i = midPoint; i < histSize; i++) { sumRight += histogram.at(i); } - const int32_t thresholdType = - (sumLeft < sumRight) ? cv::THRESH_BINARY_INV : cv::THRESH_BINARY; + const int32_t thresholdType = (sumLeft < sumRight) ? cv::THRESH_BINARY_INV : cv::THRESH_BINARY; // 2. Binarize using Otsu's method (auto threshold). cv::Mat thresh; @@ -120,16 +116,15 @@ cv::Mat characterBitMask(const cv::Mat &img) { // 3. Find the largest connected component near the center. cv::Mat labels, stats, centroids; - const int32_t numLabels = cv::connectedComponentsWithStats( - thresh, labels, stats, centroids, 8, CV_32S); + const int32_t numLabels = + cv::connectedComponentsWithStats(thresh, labels, stats, centroids, 8, CV_32S); const int32_t height = thresh.rows; const int32_t width = thresh.cols; const int32_t minX = constants::kSingleCharacterCenterThreshold * width; const int32_t maxX = (1 - constants::kSingleCharacterCenterThreshold) * width; const int32_t minY = constants::kSingleCharacterCenterThreshold * height; - const int32_t maxY = - (1 - constants::kSingleCharacterCenterThreshold) * height; + const int32_t maxY = (1 - constants::kSingleCharacterCenterThreshold) * height; int32_t selectedComponent = -1; int32_t maxArea = -1; @@ -138,9 +133,8 @@ cv::Mat characterBitMask(const cv::Mat &img) { const double cx = centroids.at(i, 0); const double cy = centroids.at(i, 1); - if ((minX < cx && cx < maxX && minY < cy && - cy < maxY && // check if centered - area > constants::kSingleCharacterMinSize) && // check if large enough + if ((minX < cx && cx < maxX && minY < cy && cy < maxY && // check if centered + area > constants::kSingleCharacterMinSize) && // check if large enough area > maxArea) { selectedComponent = i; maxArea = area; @@ -161,8 +155,7 @@ cv::Mat characterBitMask(const cv::Mat &img) { return resultImage; } -cv::Mat cropImageWithBoundingBox(const cv::Mat &img, const BBox &bbox, - const BBox &originalBbox, +cv::Mat cropImageWithBoundingBox(const cv::Mat &img, const BBox &bbox, const BBox &originalBbox, const types::PaddingInfo &paddings, const types::PaddingInfo &originalPaddings) { if (!originalBbox.isValid()) { @@ -207,16 +200,14 @@ cv::Mat cropImageWithBoundingBox(const cv::Mat &img, const BBox &bbox, } cv::Mat prepareForRecognition(const cv::Mat &originalImage, const BBox &bbox, - const BBox &originalBbox, - const types::PaddingInfo &paddings, + const BBox &originalBbox, const types::PaddingInfo &paddings, const types::PaddingInfo &originalPaddings) { - auto croppedChar = cropImageWithBoundingBox(originalImage, bbox, originalBbox, - paddings, originalPaddings); + auto croppedChar = + cropImageWithBoundingBox(originalImage, bbox, originalBbox, paddings, originalPaddings); cv::cvtColor(croppedChar, croppedChar, cv::COLOR_BGR2GRAY); cv::resize(croppedChar, croppedChar, - cv::Size(constants::kSmallVerticalRecognizerWidth, - constants::kRecognizerHeight), - 0, 0, cv::INTER_AREA); + cv::Size(constants::kSmallVerticalRecognizerWidth, constants::kRecognizerHeight), 0, 0, + cv::INTER_AREA); return croppedChar; } } // namespace rnexecutorch::models::ocr::utils diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/RecognizerUtils.h b/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/RecognizerUtils.h index d693193386..2a7b3712c9 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/RecognizerUtils.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/utils/RecognizerUtils.h @@ -31,8 +31,7 @@ cv::Rect extractBoundingBox(const types::BBox &bbox); * https://github.com/JaidedAI/EasyOCR/blob/c4f3cd7225efd4f85451bd8b4a7646ae9a092420/easyocr/recognition.py#L14 * @details 'Some say that it's a code, sent to us from god' */ -float confidenceScore(const std::vector &values, - const std::vector &indices); +float confidenceScore(const std::vector &values, const std::vector &indices); cv::Mat characterBitMask(const cv::Mat &img); @@ -59,9 +58,7 @@ cv::Mat cropImageWithBoundingBox(const cv::Mat &img, const types::BBox &bbox, * * @details it utilizes cropImageWithBoundingBox to perform specific cropping. */ -cv::Mat prepareForRecognition(const cv::Mat &originalImage, - const types::BBox &bbox, - const types::BBox &originalBbox, - const types::PaddingInfo &paddings, +cv::Mat prepareForRecognition(const cv::Mat &originalImage, const types::BBox &bbox, + const types::BBox &originalBbox, const types::PaddingInfo &paddings, const types::PaddingInfo &originalPaddings); } // namespace rnexecutorch::models::ocr::utils diff --git a/packages/react-native-executorch/common/rnexecutorch/models/pose_estimation/PoseEstimation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/pose_estimation/PoseEstimation.cpp index 03147cd468..fb8bfcefca 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/pose_estimation/PoseEstimation.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/pose_estimation/PoseEstimation.cpp @@ -11,28 +11,24 @@ namespace rnexecutorch::models::pose_estimation { -PoseEstimation::PoseEstimation(const std::string &modelSource, - std::vector normMean, +PoseEstimation::PoseEstimation(const std::string &modelSource, std::vector normMean, std::vector normStd, std::shared_ptr callInvoker) : VisionModel(modelSource, callInvoker) { if (normMean.size() == 3) { normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]); } else if (!normMean.empty()) { - log(LOG_LEVEL::Warn, - "normMean must have 3 elements — ignoring provided value."); + log(LOG_LEVEL::Warn, "normMean must have 3 elements — ignoring provided value."); } if (normStd.size() == 3) { normStd_ = cv::Scalar(normStd[0], normStd[1], normStd[2]); } else if (!normStd.empty()) { - log(LOG_LEVEL::Warn, - "normStd must have 3 elements — ignoring provided value."); + log(LOG_LEVEL::Warn, "normStd must have 3 elements — ignoring provided value."); } } PoseDetections PoseEstimation::postprocess(const std::vector &tensors, - cv::Size originalSize, - double detectionThreshold, + cv::Size originalSize, double detectionThreshold, double keypointThreshold) { // Output tensors (batch dim squeezed): // 0: boxes (Q, 4) - xyxy bbox in model input pixel space @@ -60,8 +56,7 @@ PoseDetections PoseEstimation::postprocess(const std::vector &tensors, static_cast(shape[shape.size() - 2])); float scaleX = static_cast(originalSize.width) / modelInputSize.width; - float scaleY = - static_cast(originalSize.height) / modelInputSize.height; + float scaleY = static_cast(originalSize.height) / modelInputSize.height; PoseDetections allDetections; @@ -96,8 +91,7 @@ PoseDetections PoseEstimation::postprocess(const std::vector &tensors, return allDetections; } -PoseDetections PoseEstimation::runInference(cv::Mat image, - double detectionThreshold, +PoseDetections PoseEstimation::runInference(cv::Mat image, double detectionThreshold, double keypointThreshold, const std::string &methodName) { @@ -117,29 +111,25 @@ PoseDetections PoseEstimation::runInference(cv::Mat image, auto inputShapes = getAllInputShapes(methodName); if (inputShapes.empty() || inputShapes[0].size() < 2) { throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, - "Could not determine input shape for method: " + - methodName); + "Could not determine input shape for method: " + methodName); } modelInputShape_ = inputShapes[0]; cv::Mat resizedToModelInput = preprocess(image); auto inputTensor = (normMean_ && normStd_) - ? image_processing::getTensorFromMatrix( - modelInputShape_, resizedToModelInput, *normMean_, *normStd_) - : image_processing::getTensorFromMatrix(modelInputShape_, - resizedToModelInput); + ? image_processing::getTensorFromMatrix(modelInputShape_, resizedToModelInput, *normMean_, + *normStd_) + : image_processing::getTensorFromMatrix(modelInputShape_, resizedToModelInput); auto executeResult = execute(methodName, {inputTensor}); if (!executeResult.ok()) { - throw RnExecutorchError(executeResult.error(), - "The model's " + methodName + - " method did not succeed. " - "Ensure the model input is correct."); + throw RnExecutorchError(executeResult.error(), "The model's " + methodName + + " method did not succeed. " + "Ensure the model input is correct."); } - return postprocess(executeResult.get(), originalSize, detectionThreshold, - keypointThreshold); + return postprocess(executeResult.get(), originalSize, detectionThreshold, keypointThreshold); } PoseDetections PoseEstimation::generateFromString(std::string imageSource, @@ -149,20 +139,16 @@ PoseDetections PoseEstimation::generateFromString(std::string imageSource, cv::Mat imageBGR = image_processing::readImage(imageSource); cv::Mat imageRGB; cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB); - return runInference(std::move(imageRGB), detectionThreshold, - keypointThreshold, methodName); + return runInference(std::move(imageRGB), detectionThreshold, keypointThreshold, methodName); } -PoseDetections PoseEstimation::generateFromFrame(jsi::Runtime &runtime, - const jsi::Value &frameData, +PoseDetections PoseEstimation::generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData, double detectionThreshold, - double keypointThreshold, - std::string methodName) { + double keypointThreshold, std::string methodName) { auto orient = ::rnexecutorch::utils::readFrameOrientation(runtime, frameData); cv::Mat frame = extractFromFrame(runtime, frameData); cv::Mat rotated = ::rnexecutorch::utils::rotateFrameForModel(frame, orient); - auto detections = - runInference(rotated, detectionThreshold, keypointThreshold, methodName); + auto detections = runInference(rotated, detectionThreshold, keypointThreshold, methodName); for (auto &person : detections) { ::rnexecutorch::utils::inverseRotatePoints(person, orient, rotated.size()); } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/pose_estimation/PoseEstimation.h b/packages/react-native-executorch/common/rnexecutorch/models/pose_estimation/PoseEstimation.h index 983519b34b..90d7bc4b28 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/pose_estimation/PoseEstimation.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/pose_estimation/PoseEstimation.h @@ -12,38 +12,33 @@ namespace models::pose_estimation { class PoseEstimation : public VisionModel { public: PoseEstimation(const std::string &modelSource, std::vector normMean, - std::vector normStd, - std::shared_ptr callInvoker); + std::vector normStd, std::shared_ptr callInvoker); [[nodiscard("Registered non-void function")]] PoseDetections - generateFromString(std::string imageSource, double detectionThreshold, - double keypointThreshold, std::string methodName); + generateFromString(std::string imageSource, double detectionThreshold, double keypointThreshold, + std::string methodName); [[nodiscard("Registered non-void function")]] PoseDetections - generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData, - double detectionThreshold, double keypointThreshold, - std::string methodName); + generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData, double detectionThreshold, + double keypointThreshold, std::string methodName); [[nodiscard("Registered non-void function")]] PoseDetections - generateFromPixels(JSTensorViewIn pixelData, double detectionThreshold, - double keypointThreshold, std::string methodName); + generateFromPixels(JSTensorViewIn pixelData, double detectionThreshold, double keypointThreshold, + std::string methodName); private: std::optional normMean_; std::optional normStd_; [[nodiscard("Registered non-void function")]] - PoseDetections runInference(cv::Mat image, double detectionThreshold, - double keypointThreshold, + PoseDetections runInference(cv::Mat image, double detectionThreshold, double keypointThreshold, const std::string &modelName); [[nodiscard("Registered non-void function")]] - PoseDetections postprocess(const std::vector &evl, - cv::Size originalSize, double detectionThreshold, - double keypointThreshold); + PoseDetections postprocess(const std::vector &evl, cv::Size originalSize, + double detectionThreshold, double keypointThreshold); }; } // namespace models::pose_estimation -REGISTER_CONSTRUCTOR(models::pose_estimation::PoseEstimation, std::string, - std::vector, std::vector, - std::shared_ptr); +REGISTER_CONSTRUCTOR(models::pose_estimation::PoseEstimation, std::string, std::vector, + std::vector, std::shared_ptr); } // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/privacy_filter/PrivacyFilter.cpp b/packages/react-native-executorch/common/rnexecutorch/models/privacy_filter/PrivacyFilter.cpp index d05312e47c..8649bdaa79 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/privacy_filter/PrivacyFilter.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/privacy_filter/PrivacyFilter.cpp @@ -26,20 +26,16 @@ constexpr int64_t kPadTokenId = 199999; } // namespace -PrivacyFilter::PrivacyFilter(const std::string &modelSource, - const std::string &tokenizerSource, - std::vector labelNames, - std::vector viterbiBiases, +PrivacyFilter::PrivacyFilter(const std::string &modelSource, const std::string &tokenizerSource, + std::vector labelNames, std::vector viterbiBiases, std::shared_ptr callInvoker) : BaseModel(modelSource, callInvoker), - tokenizer_( - std::make_unique(tokenizerSource, callInvoker)), + tokenizer_(std::make_unique(tokenizerSource, callInvoker)), labelNames_(std::move(labelNames)), seqLen_(0) { if (labelNames_.empty() || labelNames_[0] != "O") { - throw RnExecutorchError( - RnExecutorchErrorCode::UnknownError, - "PrivacyFilter requires a non-empty labelNames vector " - "(must include 'O' at index 0)."); + throw RnExecutorchError(RnExecutorchErrorCode::UnknownError, + "PrivacyFilter requires a non-empty labelNames vector " + "(must include 'O' at index 0)."); } if (!viterbiBiases.empty() && viterbiBiases.size() != 6) { throw RnExecutorchError(RnExecutorchErrorCode::UnknownError, @@ -55,8 +51,7 @@ PrivacyFilter::PrivacyFilter(const std::string &modelSource, biases_.insideToEnd = viterbiBiases[5]; } auto inputShapes = getAllInputShapes(); - if (inputShapes.empty() || inputShapes[0].size() < 2 || - inputShapes[0][1] < 2) { + if (inputShapes.empty() || inputShapes[0].size() < 2 || inputShapes[0][1] < 2) { throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, "PrivacyFilter: expected forward input shape " "[1, seq_len] with seq_len >= 2."); @@ -83,22 +78,19 @@ void PrivacyFilter::unload() noexcept { } void PrivacyFilter::runWindow(std::vector &paddedInputIds, - std::vector &paddedAttentionMask, - int32_t absStart, int32_t validLen, - int32_t writeFromOffset, int32_t writeToOffset, + std::vector &paddedAttentionMask, int32_t absStart, + int32_t validLen, int32_t writeFromOffset, int32_t writeToOffset, std::vector &outLabels) { if (validLen <= 0) { return; } std::vector idsShape = {1, seqLen_}; - auto inputIdsTensor = - make_tensor_ptr(idsShape, paddedInputIds.data(), ScalarType::Long); + auto inputIdsTensor = make_tensor_ptr(idsShape, paddedInputIds.data(), ScalarType::Long); auto attentionMaskTensor = make_tensor_ptr(idsShape, paddedAttentionMask.data(), ScalarType::Long); - auto forwardResult = - BaseModel::forward({*inputIdsTensor, *attentionMaskTensor}); + auto forwardResult = BaseModel::forward({*inputIdsTensor, *attentionMaskTensor}); CHECK_OK_OR_THROW_FORWARD_ERROR(forwardResult); auto &out = forwardResult.get(); if (out.empty()) { @@ -120,8 +112,7 @@ std::vector PrivacyFilter::generate(std::string text) { std::scoped_lock lock(inference_mutex_); if (!module_) { - throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, - "PrivacyFilter is not loaded"); + throw RnExecutorchError(RnExecutorchErrorCode::ModuleNotLoaded, "PrivacyFilter is not loaded"); } auto rawIds = tokenizer_->encode(text); @@ -131,12 +122,10 @@ std::vector PrivacyFilter::generate(std::string text) { const int32_t stride = seqLen_ / 2; const int32_t edgeMargin = seqLen_ / 4; - for (int32_t windowStart = 0; windowStart < totalTokens; - windowStart += stride) { + for (int32_t windowStart = 0; windowStart < totalTokens; windowStart += stride) { const int32_t validLen = std::min(seqLen_, totalTokens - windowStart); - std::vector paddedInputIds(static_cast(seqLen_), - kPadTokenId); + std::vector paddedInputIds(static_cast(seqLen_), kPadTokenId); std::vector paddedAttentionMask(static_cast(seqLen_), 0); for (int32_t i = 0; i < validLen; ++i) { paddedInputIds[static_cast(i)] = @@ -149,8 +138,8 @@ std::vector PrivacyFilter::generate(std::string text) { int32_t writeFrom = isFirst ? 0 : edgeMargin; int32_t writeTo = isLast ? validLen : seqLen_ - edgeMargin; - runWindow(paddedInputIds, paddedAttentionMask, windowStart, validLen, - writeFrom, writeTo, predictedLabels); + runWindow(paddedInputIds, paddedAttentionMask, windowStart, validLen, writeFrom, writeTo, + predictedLabels); if (isLast) { break; @@ -165,15 +154,13 @@ std::vector PrivacyFilter::generate(std::string text) { std::vector spans; int32_t i = 0; while (i < totalTokens) { - const auto entity = - labelEntityType(predictedLabels[static_cast(i)]); + const auto entity = labelEntityType(predictedLabels[static_cast(i)]); if (entity.empty()) { ++i; continue; } int32_t j = i + 1; - while (j < totalTokens && - labelEntityType(predictedLabels[static_cast(j)]) == entity) { + while (j < totalTokens && labelEntityType(predictedLabels[static_cast(j)]) == entity) { ++j; } spans.emplace_back(i, j, entity); @@ -195,16 +182,14 @@ std::vector PrivacyFilter::generate(std::string text) { } constexpr auto notSpace = [](unsigned char c) { return !std::isspace(c); }; auto left = std::ranges::find_if(decoded, notSpace); - auto right = - std::ranges::find_if(decoded.rbegin(), decoded.rend(), notSpace).base(); + auto right = std::ranges::find_if(decoded.rbegin(), decoded.rend(), notSpace).base(); if (left < right) { decoded.assign(left, right); } else { decoded.clear(); } - entities.emplace_back(span.entity, std::move(decoded), span.start, - span.end); + entities.emplace_back(span.entity, std::move(decoded), span.start, span.end); } return entities; } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/privacy_filter/PrivacyFilter.h b/packages/react-native-executorch/common/rnexecutorch/models/privacy_filter/PrivacyFilter.h index 880a9709f1..b16998da94 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/privacy_filter/PrivacyFilter.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/privacy_filter/PrivacyFilter.h @@ -19,10 +19,8 @@ using namespace facebook; class PrivacyFilter final : public BaseModel { public: - PrivacyFilter(const std::string &modelSource, - const std::string &tokenizerSource, - std::vector labelNames, - std::vector viterbiBiases, + PrivacyFilter(const std::string &modelSource, const std::string &tokenizerSource, + std::vector labelNames, std::vector viterbiBiases, std::shared_ptr callInvoker); [[nodiscard("Registered non-void function")]] std::vector @@ -31,10 +29,9 @@ class PrivacyFilter final : public BaseModel { void unload() noexcept; private: - void runWindow(std::vector &paddedInputIds, - std::vector &paddedAttentionMask, int32_t absStart, - int32_t validLen, int32_t writeFromOffset, - int32_t writeToOffset, std::vector &outLabels); + void runWindow(std::vector &paddedInputIds, std::vector &paddedAttentionMask, + int32_t absStart, int32_t validLen, int32_t writeFromOffset, int32_t writeToOffset, + std::vector &outLabels); std::string labelEntityType(int32_t labelId) const; @@ -48,7 +45,7 @@ class PrivacyFilter final : public BaseModel { } // namespace models::privacy_filter -REGISTER_CONSTRUCTOR(models::privacy_filter::PrivacyFilter, std::string, - std::string, std::vector, std::vector, +REGISTER_CONSTRUCTOR(models::privacy_filter::PrivacyFilter, std::string, std::string, + std::vector, std::vector, std::shared_ptr); } // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/privacy_filter/Viterbi.cpp b/packages/react-native-executorch/common/rnexecutorch/models/privacy_filter/Viterbi.cpp index 214ddbbea6..2db1a17b60 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/privacy_filter/Viterbi.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/privacy_filter/Viterbi.cpp @@ -33,8 +33,7 @@ bool isValidTransition(const LabelRole &prev, const LabelRole &nxt) { return nxt.prefix == 'O' || nxt.prefix == 'B' || nxt.prefix == 'S'; } if (prev.prefix == 'B' || prev.prefix == 'I') { - return (nxt.prefix == 'I' || nxt.prefix == 'E') && - nxt.entity == prev.entity; + return (nxt.prefix == 'I' || nxt.prefix == 'E') && nxt.entity == prev.entity; } return false; } @@ -49,8 +48,7 @@ float biasFor(const LabelRole &prev, const LabelRole &nxt, const Biases &b) { if ((prev.prefix == 'E' || prev.prefix == 'S') && nxt.prefix == 'O') { return b.endToBackground; } - if ((prev.prefix == 'E' || prev.prefix == 'S') && - (nxt.prefix == 'B' || nxt.prefix == 'S')) { + if ((prev.prefix == 'E' || prev.prefix == 'S') && (nxt.prefix == 'B' || nxt.prefix == 'S')) { return b.endToStart; } if ((prev.prefix == 'B' || prev.prefix == 'I') && nxt.prefix == 'I') { @@ -64,8 +62,7 @@ float biasFor(const LabelRole &prev, const LabelRole &nxt, const Biases &b) { } // namespace -Grammar buildGrammar(const std::vector &labelNames, - const Biases &biases) { +Grammar buildGrammar(const std::vector &labelNames, const Biases &biases) { const size_t N = labelNames.size(); std::vector roles; roles.reserve(N); @@ -81,22 +78,20 @@ Grammar buildGrammar(const std::vector &labelNames, for (size_t i = 0; i < N; ++i) { for (size_t j = 0; j < N; ++j) { if (isValidTransition(roles[i], roles[j])) { - grammar.transitionScore[i * N + j] = - biasFor(roles[i], roles[j], biases); + grammar.transitionScore[i * N + j] = biasFor(roles[i], roles[j], biases); } } } grammar.validStart.assign(N, false); for (size_t i = 0; i < N; ++i) { - grammar.validStart[i] = (roles[i].prefix == 'O' || roles[i].prefix == 'B' || - roles[i].prefix == 'S'); + grammar.validStart[i] = + (roles[i].prefix == 'O' || roles[i].prefix == 'B' || roles[i].prefix == 'S'); } return grammar; } -std::vector decode(const float *logits, int32_t validLen, - const Grammar &grammar) { +std::vector decode(const float *logits, int32_t validLen, const Grammar &grammar) { if (validLen <= 0) { return {}; } @@ -144,8 +139,7 @@ std::vector decode(const float *logits, int32_t validLen, path[static_cast(validLen) - 1] = static_cast(bestEnd); for (int32_t t = validLen - 1; t > 0; --t) { path[static_cast(t) - 1] = - bp[static_cast(t) * N + - static_cast(path[static_cast(t)])]; + bp[static_cast(t) * N + static_cast(path[static_cast(t)])]; } return path; } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/privacy_filter/Viterbi.h b/packages/react-native-executorch/common/rnexecutorch/models/privacy_filter/Viterbi.h index ec0912c7fb..42bcc7db5c 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/privacy_filter/Viterbi.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/privacy_filter/Viterbi.h @@ -30,12 +30,10 @@ struct Grammar { size_t numLabels = 0; }; -Grammar buildGrammar(const std::vector &labelNames, - const Biases &biases); +Grammar buildGrammar(const std::vector &labelNames, const Biases &biases); // Run constrained Viterbi over [validLen, numLabels] logits and return the // best BIOES-grammar-valid label-id sequence (length validLen). -std::vector decode(const float *logits, int32_t validLen, - const Grammar &grammar); +std::vector decode(const float *logits, int32_t validLen, const Grammar &grammar); } // namespace rnexecutorch::models::privacy_filter::viterbi diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp index ed88434533..74ad5747e0 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp @@ -11,24 +11,22 @@ namespace rnexecutorch::models::semantic_segmentation { -BaseSemanticSegmentation::BaseSemanticSegmentation( - const std::string &modelSource, std::vector normMean, - std::vector normStd, std::vector allClasses, - std::shared_ptr callInvoker) - : VisionModel(modelSource, callInvoker), - allClasses_(std::move(allClasses)) { +BaseSemanticSegmentation::BaseSemanticSegmentation(const std::string &modelSource, + std::vector normMean, + std::vector normStd, + std::vector allClasses, + std::shared_ptr callInvoker) + : VisionModel(modelSource, callInvoker), allClasses_(std::move(allClasses)) { initModelImageSize(); if (normMean.size() == 3) { normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]); } else if (!normMean.empty()) { - log(LOG_LEVEL::Warn, - "normMean must have 3 elements — ignoring provided value."); + log(LOG_LEVEL::Warn, "normMean must have 3 elements — ignoring provided value."); } if (normStd.size() == 3) { normStd_ = cv::Scalar(normStd[0], normStd[1], normStd[2]); } else if (!normStd.empty()) { - log(LOG_LEVEL::Warn, - "normStd must have 3 elements — ignoring provided value."); + log(LOG_LEVEL::Warn, "normStd must have 3 elements — ignoring provided value."); } } @@ -49,30 +47,26 @@ void BaseSemanticSegmentation::initModelImageSize() { } semantic_segmentation::SegmentationResult -BaseSemanticSegmentation::runInference( - cv::Mat image, cv::Size originalSize, - std::set> &classesOfInterest, bool resize) { +BaseSemanticSegmentation::runInference(cv::Mat image, cv::Size originalSize, + std::set> &classesOfInterest, + bool resize) { std::scoped_lock lock(inference_mutex_); cv::Mat preprocessed = VisionModel::preprocess(image); - auto inputTensor = - (normMean_ && normStd_) - ? image_processing::getTensorFromMatrix( - modelInputShape_, preprocessed, *normMean_, *normStd_) - : image_processing::getTensorFromMatrix(modelInputShape_, - preprocessed); + auto inputTensor = (normMean_ && normStd_) + ? image_processing::getTensorFromMatrix(modelInputShape_, preprocessed, + *normMean_, *normStd_) + : image_processing::getTensorFromMatrix(modelInputShape_, preprocessed); auto forwardResult = BaseModel::forward(inputTensor); CHECK_OK_OR_THROW_FORWARD_ERROR(forwardResult); - return computeResult(forwardResult->at(0).toTensor(), originalSize, - allClasses_, classesOfInterest, resize); + return computeResult(forwardResult->at(0).toTensor(), originalSize, allClasses_, + classesOfInterest, resize); } -semantic_segmentation::SegmentationResult -BaseSemanticSegmentation::generateFromString( - std::string imageSource, - std::set> classesOfInterest, bool resize) { +semantic_segmentation::SegmentationResult BaseSemanticSegmentation::generateFromString( + std::string imageSource, std::set> classesOfInterest, bool resize) { cv::Mat imageBGR = image_processing::readImage(imageSource); cv::Size originalSize = imageBGR.size(); cv::Mat imageRGB; @@ -81,18 +75,16 @@ BaseSemanticSegmentation::generateFromString( return runInference(imageRGB, originalSize, classesOfInterest, resize); } -semantic_segmentation::SegmentationResult -BaseSemanticSegmentation::generateFromPixels( - JSTensorViewIn pixelData, - std::set> classesOfInterest, bool resize) { +semantic_segmentation::SegmentationResult BaseSemanticSegmentation::generateFromPixels( + JSTensorViewIn pixelData, std::set> classesOfInterest, bool resize) { cv::Mat image = extractFromPixels(pixelData); return runInference(image, image.size(), classesOfInterest, resize); } semantic_segmentation::SegmentationResult -BaseSemanticSegmentation::generateFromFrame( - jsi::Runtime &runtime, const jsi::Value &frameData, - std::set> classesOfInterest, bool resize) { +BaseSemanticSegmentation::generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData, + std::set> classesOfInterest, + bool resize) { auto orient = ::rnexecutorch::utils::readFrameOrientation(runtime, frameData); cv::Mat frame = extractFromFrame(runtime, frameData); cv::Mat rotated = utils::rotateFrameForModel(frame, orient); @@ -104,15 +96,15 @@ BaseSemanticSegmentation::generateFromFrame( const cv::Size frameSize = frame.size(); auto inverseAndResize = [&orient, &frameSize, &outputSize, - resize](std::shared_ptr &buf, - int32_t cvType, int32_t interpFlag) { + resize](std::shared_ptr &buf, int32_t cvType, + int32_t interpFlag) { cv::Mat m(outputSize, cvType, buf->data()); cv::Mat inv = utils::inverseRotateMat(m, orient); if (resize && inv.size() != frameSize) { cv::resize(inv, inv, frameSize, 0, 0, interpFlag); } - buf = std::make_shared( - inv.data, static_cast(inv.total() * inv.elemSize())); + buf = std::make_shared(inv.data, + static_cast(inv.total() * inv.elemSize())); }; if (outputSize.area() > 0) { @@ -129,18 +121,15 @@ BaseSemanticSegmentation::generateFromFrame( return result; } -semantic_segmentation::SegmentationResult -BaseSemanticSegmentation::computeResult( - const Tensor &tensor, cv::Size originalSize, - std::vector &allClasses, +semantic_segmentation::SegmentationResult BaseSemanticSegmentation::computeResult( + const Tensor &tensor, cv::Size originalSize, std::vector &allClasses, std::set> &classesOfInterest, bool resize) { const auto *dataPtr = tensor.const_data_ptr(); auto resultData = std::span(dataPtr, tensor.numel()); // Read output dimensions directly from tensor shape - std::size_t numChannels = - (tensor.dim() >= 3) ? tensor.size(tensor.dim() - 3) : 1; + std::size_t numChannels = (tensor.dim() >= 3) ? tensor.size(tensor.dim() - 3) : 1; std::size_t outputH = tensor.size(tensor.dim() - 2); std::size_t outputW = tensor.size(tensor.dim() - 1); std::size_t outputPixels = outputH * outputW; @@ -152,8 +141,7 @@ BaseSemanticSegmentation::computeResult( if (numChannels == 1) { // Binary segmentation (e.g. selfie segmentation) - auto fg = std::make_shared(resultData.data(), - outputPixels * sizeof(float)); + auto fg = std::make_shared(resultData.data(), outputPixels * sizeof(float)); auto bg = std::make_shared(outputPixels * sizeof(float)); auto *fgPtr = reinterpret_cast(fg->data()); auto *bgPtr = reinterpret_cast(bg->data()); @@ -171,8 +159,7 @@ BaseSemanticSegmentation::computeResult( } // Softmax + argmax in class-major order - auto argmax = - std::make_shared(outputPixels * sizeof(int32_t)); + auto argmax = std::make_shared(outputPixels * sizeof(int32_t)); auto *argmaxPtr = reinterpret_cast(argmax->data()); if (numChannels == 1) { @@ -181,8 +168,7 @@ BaseSemanticSegmentation::computeResult( argmaxPtr[pixel] = (fgPtr[pixel] > 0.5f) ? 0 : 1; } } else { - std::vector maxLogits(outputPixels, - -std::numeric_limits::infinity()); + std::vector maxLogits(outputPixels, -std::numeric_limits::infinity()); std::vector sumExp(outputPixels, 0.0f); // Pass 1: find per-pixel max and argmax @@ -214,8 +200,8 @@ BaseSemanticSegmentation::computeResult( } } - auto buffersToReturn = std::make_shared< - std::unordered_map>>(); + auto buffersToReturn = + std::make_shared>>(); bool returnAllClasses = classesOfInterest.empty(); for (std::size_t cl = 0; cl < resultClasses.size(); ++cl) { if (cl < allClasses.size() && @@ -227,16 +213,15 @@ BaseSemanticSegmentation::computeResult( // Resize selected classes and argmax if (resize) { cv::Mat argmaxMat(outputSize, CV_32SC1, argmax->data()); - cv::resize(argmaxMat, argmaxMat, originalSize, 0, 0, - cv::InterpolationFlags::INTER_NEAREST); - argmax = std::make_shared( - argmaxMat.data, originalSize.area() * sizeof(int32_t)); + cv::resize(argmaxMat, argmaxMat, originalSize, 0, 0, cv::InterpolationFlags::INTER_NEAREST); + argmax = + std::make_shared(argmaxMat.data, originalSize.area() * sizeof(int32_t)); for (auto &[label, arrayBuffer] : *buffersToReturn) { cv::Mat classMat(outputSize, CV_32FC1, arrayBuffer->data()); cv::resize(classMat, classMat, originalSize); - arrayBuffer = std::make_shared( - classMat.data, originalSize.area() * sizeof(float)); + arrayBuffer = + std::make_shared(classMat.data, originalSize.area() * sizeof(float)); } } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h index a30ae375bf..3414b3eb52 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h @@ -16,36 +16,29 @@ using executorch::aten::Tensor; class BaseSemanticSegmentation : public VisionModel { public: - BaseSemanticSegmentation(const std::string &modelSource, - std::vector normMean, - std::vector normStd, - std::vector allClasses, + BaseSemanticSegmentation(const std::string &modelSource, std::vector normMean, + std::vector normStd, std::vector allClasses, std::shared_ptr callInvoker); [[nodiscard("Registered non-void function")]] semantic_segmentation::SegmentationResult - generateFromString(std::string imageSource, - std::set> classesOfInterest, + generateFromString(std::string imageSource, std::set> classesOfInterest, bool resize); [[nodiscard("Registered non-void function")]] semantic_segmentation::SegmentationResult - generateFromPixels(JSTensorViewIn pixelData, - std::set> classesOfInterest, + generateFromPixels(JSTensorViewIn pixelData, std::set> classesOfInterest, bool resize); [[nodiscard("Registered non-void function")]] semantic_segmentation::SegmentationResult generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData, - std::set> classesOfInterest, - bool resize); + std::set> classesOfInterest, bool resize); protected: virtual semantic_segmentation::SegmentationResult - computeResult(const Tensor &tensor, cv::Size originalSize, - std::vector &allClasses, - std::set> &classesOfInterest, - bool resize); + computeResult(const Tensor &tensor, cv::Size originalSize, std::vector &allClasses, + std::set> &classesOfInterest, bool resize); std::size_t numModelPixels; std::optional normMean_; std::optional normStd_; @@ -56,13 +49,11 @@ class BaseSemanticSegmentation : public VisionModel { semantic_segmentation::SegmentationResult runInference(cv::Mat image, cv::Size originalSize, - std::set> &classesOfInterest, - bool resize); + std::set> &classesOfInterest, bool resize); }; } // namespace models::semantic_segmentation -REGISTER_CONSTRUCTOR(models::semantic_segmentation::BaseSemanticSegmentation, - std::string, std::vector, std::vector, - std::vector, +REGISTER_CONSTRUCTOR(models::semantic_segmentation::BaseSemanticSegmentation, std::string, + std::vector, std::vector, std::vector, std::shared_ptr); } // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/Types.h b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/Types.h index b305b96a70..d9517f9ce8 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/Types.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/Types.h @@ -9,9 +9,7 @@ namespace rnexecutorch::models::semantic_segmentation { struct SegmentationResult { std::shared_ptr argmax; - std::shared_ptr< - std::unordered_map>> - classBuffers; + std::shared_ptr>> classBuffers; }; } // namespace rnexecutorch::models::semantic_segmentation diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp index f3ec9f755f..13c911d7ce 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.cpp @@ -9,10 +9,8 @@ namespace rnexecutorch::models::speech_to_text { -SpeechToText::SpeechToText(const std::string &modelName, - const std::string &modelSource, - const std::string &tokenizerSource, - const std::string &vadSource, +SpeechToText::SpeechToText(const std::string &modelName, const std::string &modelSource, + const std::string &tokenizerSource, const std::string &vadSource, std::shared_ptr callInvoker) : callInvoker_(std::move(callInvoker)) { // Switch between the ASR implementations based on model name @@ -21,49 +19,39 @@ SpeechToText::SpeechToText(const std::string &modelName, vad_ = std::make_unique(vadSource, callInvoker_); } - transcriber_ = std::make_unique(modelSource, tokenizerSource, - callInvoker_); + transcriber_ = std::make_unique(modelSource, tokenizerSource, callInvoker_); streamer_ = std::make_unique( static_cast(transcriber_.get()), static_cast(vad_.get())); } else { - throw rnexecutorch::RnExecutorchError( - rnexecutorch::RnExecutorchErrorCode::InvalidConfig, - "[SpeechToText]: Invalid model name: " + modelName); + throw rnexecutorch::RnExecutorchError(rnexecutorch::RnExecutorchErrorCode::InvalidConfig, + "[SpeechToText]: Invalid model name: " + modelName); } } SpeechToText::SpeechToText(SpeechToText &&other) noexcept - : callInvoker_(std::move(other.callInvoker_)), - transcriber_(std::move(other.transcriber_)), - streamer_(std::move(other.streamer_)), - isStreaming_(other.isStreaming_.load()), + : callInvoker_(std::move(other.callInvoker_)), transcriber_(std::move(other.transcriber_)), + streamer_(std::move(other.streamer_)), isStreaming_(other.isStreaming_.load()), readyToProcess_(other.readyToProcess_.load()) {} void SpeechToText::unload() noexcept { transcriber_->unload(); } -std::shared_ptr -SpeechToText::encode(std::span waveform) const { +std::shared_ptr SpeechToText::encode(std::span waveform) const { executorch::aten::Tensor encoderOutputTensor = transcriber_->encode(waveform); - return std::make_shared( - encoderOutputTensor.const_data_ptr(), - sizeof(float) * encoderOutputTensor.numel()); + return std::make_shared(encoderOutputTensor.const_data_ptr(), + sizeof(float) * encoderOutputTensor.numel()); } -std::shared_ptr -SpeechToText::decode(std::span tokens, - std::span encoderOutput) const { - executorch::aten::Tensor decoderOutputTensor = - transcriber_->decode(tokens, encoderOutput); +std::shared_ptr SpeechToText::decode(std::span tokens, + std::span encoderOutput) const { + executorch::aten::Tensor decoderOutputTensor = transcriber_->decode(tokens, encoderOutput); - return std::make_shared( - decoderOutputTensor.const_data_ptr(), - sizeof(float) * decoderOutputTensor.numel()); + return std::make_shared(decoderOutputTensor.const_data_ptr(), + sizeof(float) * decoderOutputTensor.numel()); } -TranscriptionResult SpeechToText::transcribe(std::span waveform, - std::string languageOption, +TranscriptionResult SpeechToText::transcribe(std::span waveform, std::string languageOption, bool verbose) const { DecodingOptions options(languageOption, verbose); std::vector segments = transcriber_->transcribe(waveform, options); @@ -92,8 +80,8 @@ size_t SpeechToText::getMemoryLowerBound() const noexcept { } namespace { -TranscriptionResult wordsToResult(const std::vector &words, - const std::string &language, bool verbose) { +TranscriptionResult wordsToResult(const std::vector &words, const std::string &language, + bool verbose) { TranscriptionResult res; res.language = language; res.task = "stream"; @@ -120,9 +108,8 @@ TranscriptionResult wordsToResult(const std::vector &words, } } // namespace -void SpeechToText::stream(std::shared_ptr callback, - std::string languageOption, bool verbose, - uint32_t timeout, bool useVAD, +void SpeechToText::stream(std::shared_ptr callback, std::string languageOption, + bool verbose, uint32_t timeout, bool useVAD, uint32_t vadDetectionMargin) { if (isStreaming_) { throw RnExecutorchError(RnExecutorchErrorCode::StreamingInProgress, @@ -135,21 +122,16 @@ void SpeechToText::stream(std::shared_ptr callback, "Attempting to use VAD but it's not initialized!"); } - auto nativeCallback = - [this, callback](const TranscriptionResult &committed, - const TranscriptionResult &nonCommitted, bool isDone) { - // This moves execution to the JS thread - callInvoker_->invokeAsync( - [callback, committed, nonCommitted, isDone](jsi::Runtime &rt) { - jsi::Value jsiCommitted = - rnexecutorch::jsi_conversion::getJsiValue(committed, rt); - jsi::Value jsiNonCommitted = - rnexecutorch::jsi_conversion::getJsiValue(nonCommitted, rt); - - callback->call(rt, std::move(jsiCommitted), - std::move(jsiNonCommitted), jsi::Value(isDone)); - }); - }; + auto nativeCallback = [this, callback](const TranscriptionResult &committed, + const TranscriptionResult &nonCommitted, bool isDone) { + // This moves execution to the JS thread + callInvoker_->invokeAsync([callback, committed, nonCommitted, isDone](jsi::Runtime &rt) { + jsi::Value jsiCommitted = rnexecutorch::jsi_conversion::getJsiValue(committed, rt); + jsi::Value jsiNonCommitted = rnexecutorch::jsi_conversion::getJsiValue(nonCommitted, rt); + + callback->call(rt, std::move(jsiCommitted), std::move(jsiNonCommitted), jsi::Value(isDone)); + }); + }; isStreaming_ = true; StreamingOptions options(languageOption, verbose, useVAD, vadDetectionMargin); @@ -162,8 +144,7 @@ void SpeechToText::stream(std::shared_ptr callback, // correctness when VAD is used. Otherwise we run into the vanishing text // issue. if (!res.committed.empty() || !res.nonCommitted.empty()) { - TranscriptionResult committedRes = - wordsToResult(res.committed, languageOption, verbose); + TranscriptionResult committedRes = wordsToResult(res.committed, languageOption, verbose); TranscriptionResult nonCommittedRes = wordsToResult(res.nonCommitted, languageOption, verbose); @@ -190,8 +171,7 @@ void SpeechToText::stream(std::shared_ptr callback, finishOptions.useVAD = false; std::vector finalWords = streamer_->finish(finishOptions); - TranscriptionResult finalRes = - wordsToResult(finalWords, languageOption, verbose); + TranscriptionResult finalRes = wordsToResult(finalWords, languageOption, verbose); nativeCallback(finalRes, {}, true); resetStreamState(); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h index 9c084dcf6e..3fe4a2b09e 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/SpeechToText.h @@ -28,22 +28,18 @@ class SpeechToText { SpeechToText(SpeechToText &&other) noexcept; void unload() noexcept; - [[nodiscard( - "Registered non-void function")]] std::shared_ptr + [[nodiscard("Registered non-void function")]] std::shared_ptr encode(std::span waveform) const; - [[nodiscard( - "Registered non-void function")]] std::shared_ptr + [[nodiscard("Registered non-void function")]] std::shared_ptr decode(std::span tokens, std::span encoderOutput) const; [[nodiscard("Registered non-void function")]] - TranscriptionResult transcribe(std::span waveform, - std::string languageOption, + TranscriptionResult transcribe(std::span waveform, std::string languageOption, bool verbose) const; size_t getMemoryLowerBound() const noexcept; // Stream - void stream(std::shared_ptr callback, - std::string languageOption, bool verbose, + void stream(std::shared_ptr callback, std::string languageOption, bool verbose, uint32_t timeout, bool useVAD, uint32_t vadDetectionMargin); void streamStop(); void streamInsert(std::span waveform); @@ -72,8 +68,7 @@ class SpeechToText { } // namespace models::speech_to_text -REGISTER_CONSTRUCTOR(models::speech_to_text::SpeechToText, std::string, - std::string, std::string, std::string, - std::shared_ptr); +REGISTER_CONSTRUCTOR(models::speech_to_text::SpeechToText, std::string, std::string, std::string, + std::string, std::shared_ptr); } // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/ASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/ASR.h index 61ca8b6871..ad541cddfb 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/ASR.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/schema/ASR.h @@ -21,12 +21,10 @@ class ASR { public: virtual ~ASR() = default; - virtual std::vector - transcribe(std::span waveform, - const DecodingOptions &options) const = 0; + virtual std::vector transcribe(std::span waveform, + const DecodingOptions &options) const = 0; - virtual executorch::aten::Tensor - encode(std::span waveform) const = 0; + virtual executorch::aten::Tensor encode(std::span waveform) const = 0; virtual executorch::aten::Tensor decode(std::span tokens, std::span encoderOutput, @@ -38,4 +36,4 @@ class ASR { virtual std::size_t getMemoryLowerBound() const noexcept = 0; }; -} // namespace rnexecutorch::models::speech_to_text::schema \ No newline at end of file +} // namespace rnexecutorch::models::speech_to_text::schema diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Options.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Options.h index 760106d3f9..8e90a30d54 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Options.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Options.h @@ -7,18 +7,17 @@ namespace rnexecutorch::models::speech_to_text { struct DecodingOptions { DecodingOptions(const std::string &language, bool verbose = false) - : language(language.empty() ? std::nullopt : std::optional(language)), - verbose(verbose) {} + : language(language.empty() ? std::nullopt : std::optional(language)), verbose(verbose) {} std::optional language; bool verbose; }; struct StreamingOptions : public DecodingOptions { - StreamingOptions(const std::string &language, bool verbose = false, - bool useVAD = false, uint32_t vadDetectionMargin = 500) - : DecodingOptions(language, verbose), useVAD(useVAD), - vadDetectionMargin(vadDetectionMargin) {} + StreamingOptions(const std::string &language, bool verbose = false, bool useVAD = false, + uint32_t vadDetectionMargin = 500) + : DecodingOptions(language, verbose), useVAD(useVAD), vadDetectionMargin(vadDetectionMargin) { + } bool useVAD; uint32_t vadDetectionMargin; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Token.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Token.h index 17c4a40914..db51efc2e1 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Token.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/Token.h @@ -6,4 +6,4 @@ namespace rnexecutorch::models::speech_to_text { using Token = uint64_t; -} // namespace rnexecutorch::models::speech_to_text \ No newline at end of file +} // namespace rnexecutorch::models::speech_to_text diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/TranscriptionResult.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/TranscriptionResult.h index 994cdb15eb..08b0ad6a37 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/TranscriptionResult.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/common/types/TranscriptionResult.h @@ -13,4 +13,4 @@ struct TranscriptionResult { std::vector segments; // Populated only if verbose=true }; -} // namespace rnexecutorch::models::speech_to_text \ No newline at end of file +} // namespace rnexecutorch::models::speech_to_text diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.cpp index a9f2b152b4..a3a467d574 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.cpp @@ -19,14 +19,10 @@ using executorch::runtime::etensor::ScalarType; ASR::ASR(const std::string &modelSource, const std::string &tokenizerSource, std::shared_ptr callInvoker) : BaseModel(modelSource, std::move(callInvoker)), schema::ASR(), - tokenizer_(std::make_unique(tokenizerSource, - this->callInvoker)), - startOfTranscriptionToken_( - tokenizer_->tokenToId(constants::tokens::kStartOfTranscript)), - endOfTranscriptionToken_( - tokenizer_->tokenToId(constants::tokens::kEndOfTranscript)), - timestampBeginToken_( - tokenizer_->tokenToId(constants::tokens::kBeginTimestamp)) {} + tokenizer_(std::make_unique(tokenizerSource, this->callInvoker)), + startOfTranscriptionToken_(tokenizer_->tokenToId(constants::tokens::kStartOfTranscript)), + endOfTranscriptionToken_(tokenizer_->tokenToId(constants::tokens::kEndOfTranscript)), + timestampBeginToken_(tokenizer_->tokenToId(constants::tokens::kBeginTimestamp)) {} /** * Whisper inference - full transcription @@ -38,19 +34,17 @@ std::vector ASR::transcribe(std::span waveform, std::vector results; const float waveformSize = static_cast(waveform.size()); - const float waveformSkipBoundary = - static_cast((constants::kChunkSize - params::kChunkBreakBuffer) * - constants::kSamplingRate); + const float waveformSkipBoundary = static_cast( + (constants::kChunkSize - params::kChunkBreakBuffer) * constants::kSamplingRate); // We loop through the input audio waveform and process it in 30s chunks. // This is determined by Whisper models strict 30s audio length requirement. while (seek * constants::kSamplingRate < waveformSize) { // Calculate chunk bounds and extract the chunk. float start = seek * constants::kSamplingRate; - const auto end = - std::min(static_cast((seek + constants::kChunkSize) * - constants::kSamplingRate), - waveformSize); + const auto end = std::min( + static_cast((seek + constants::kChunkSize) * constants::kSamplingRate), + waveformSize); auto chunk = waveform.subspan(start, end - start); if (std::cmp_less(chunk.size(), constants::kMinChunkSamples)) { @@ -83,9 +77,8 @@ std::vector ASR::transcribe(std::span waveform, // This prevents additional segments to appear, unless the audio length is // very close to the max chunk size, that is there could be some words // spoken near the breakpoint. - seek = waveformSize < waveformSkipBoundary - ? seek + constants::kChunkSize - : segments.back().words.back().end; + seek = waveformSize < waveformSkipBoundary ? seek + constants::kChunkSize + : segments.back().words.back().end; } results.insert(results.end(), std::make_move_iterator(segments.begin()), std::make_move_iterator(segments.end())); @@ -104,15 +97,13 @@ executorch::aten::Tensor ASR::encode(std::span waveform) const { auto inputShape = {static_cast(waveform.size())}; const auto modelInputTensor = executorch::extension::make_tensor_ptr( - std::move(inputShape), const_cast(waveform.data()), - ScalarType::Float); + std::move(inputShape), const_cast(waveform.data()), ScalarType::Float); const auto encoderResult = this->execute("encode", {modelInputTensor}); if (!encoderResult.ok()) { - throw RnExecutorchError(encoderResult.error(), - "[Whisper] The 'encode' method did not succeed. " - "Ensure the model input is correct."); + throw RnExecutorchError(encoderResult.error(), "[Whisper] The 'encode' method did not succeed. " + "Ensure the model input is correct."); } return encoderResult.get().at(0).toTensor(); @@ -129,30 +120,26 @@ executorch::aten::Tensor ASR::decode(std::span tokens, std::vector tokenShape = {1, static_cast(tokens.size())}; std::vector positionShape = {static_cast(tokens.size())}; - auto tokenTensor = executorch::extension::make_tensor_ptr( - tokenShape, tokens.data(), ScalarType::Long); + auto tokenTensor = + executorch::extension::make_tensor_ptr(tokenShape, tokens.data(), ScalarType::Long); // Populate cache position vector std::vector cachePositions(tokens.size()); std::iota(cachePositions.begin(), cachePositions.end(), startPos); - auto positionTensor = executorch::extension::make_tensor_ptr( - positionShape, cachePositions.data(), ScalarType::Long); + auto positionTensor = executorch::extension::make_tensor_ptr(positionShape, cachePositions.data(), + ScalarType::Long); const auto encoderOutputSize = static_cast(encoderOutput.size()); - std::vector encShape = { - 1, static_cast(constants::kNumFrames), - encoderOutputSize / static_cast(constants::kNumFrames)}; + std::vector encShape = {1, static_cast(constants::kNumFrames), + encoderOutputSize / static_cast(constants::kNumFrames)}; auto encoderTensor = executorch::extension::make_tensor_ptr( - std::move(encShape), const_cast(encoderOutput.data()), - ScalarType::Float); + std::move(encShape), const_cast(encoderOutput.data()), ScalarType::Float); - const auto decoderResult = - this->execute("decode", {tokenTensor, positionTensor, encoderTensor}); + const auto decoderResult = this->execute("decode", {tokenTensor, positionTensor, encoderTensor}); if (!decoderResult.ok()) { - throw RnExecutorchError(decoderResult.error(), - "[Whisper] The 'decode' method did not succeed. " - "Ensure the model inputs are correct."); + throw RnExecutorchError(decoderResult.error(), "[Whisper] The 'decode' method did not succeed. " + "Ensure the model inputs are correct."); } return decoderResult.get().at(0).toTensor(); @@ -160,21 +147,17 @@ executorch::aten::Tensor ASR::decode(std::span tokens, void ASR::unload() noexcept { BaseModel::unload(); } -std::size_t ASR::getMemoryLowerBound() const noexcept { - return BaseModel::getMemoryLowerBound(); -} +std::size_t ASR::getMemoryLowerBound() const noexcept { return BaseModel::getMemoryLowerBound(); } /** * Helper functions - creating initial token IDs sequence */ -std::vector -ASR::createInitialSequence(const DecodingOptions &options) const { +std::vector ASR::createInitialSequence(const DecodingOptions &options) const { std::vector seq; seq.push_back(startOfTranscriptionToken_); if (options.language.has_value()) { - uint64_t langToken = - tokenizer_->tokenToId("<|" + options.language.value() + "|>"); + uint64_t langToken = tokenizer_->tokenToId("<|" + options.language.value() + "|>"); uint64_t taskToken = tokenizer_->tokenToId("<|transcribe|>"); seq.push_back(langToken); seq.push_back(taskToken); @@ -191,15 +174,13 @@ ASR::createInitialSequence(const DecodingOptions &options) const { std::vector ASR::generate(std::span waveform, const DecodingOptions &options) const { // A fixed pool of available temperatures - constexpr std::array temperatures = {0.0f, 0.2f, 0.4f, - 0.6f, 0.8f, 1.0f}; + constexpr std::array temperatures = {0.0f, 0.2f, 0.4f, 0.6f, 0.8f, 1.0f}; // Calculate audio features just once to save time. executorch::aten::Tensor encoderFeaturesTensor = this->encode(waveform); - const float *encoderFeaturesData = - encoderFeaturesTensor.const_data_ptr(); - std::span encoderFeatures( - encoderFeaturesData, encoderFeaturesData + encoderFeaturesTensor.numel()); + const float *encoderFeaturesData = encoderFeaturesTensor.const_data_ptr(); + std::span encoderFeatures(encoderFeaturesData, + encoderFeaturesData + encoderFeaturesTensor.numel()); std::vector bestTokens; float bestAvgLogProb = -std::numeric_limits::infinity(); @@ -207,16 +188,14 @@ std::vector ASR::generate(std::span waveform, float bestTemperature = 0.0f; for (auto t : temperatures) { - auto [tokens, scores] = - this->generate(waveform, options, t, {encoderFeatures}); + auto [tokens, scores] = this->generate(waveform, options, t, {encoderFeatures}); - const float cumLogProb = std::transform_reduce( - scores.begin(), scores.end(), 0.0f, std::plus<>(), - [](float s) { return std::log(std::max(s, 1e-9f)); }); + const float cumLogProb = + std::transform_reduce(scores.begin(), scores.end(), 0.0f, std::plus<>(), + [](float s) { return std::log(std::max(s, 1e-9f)); }); // Match whisper.cpp: divide by the number of summed log-probs. - const float avgLogProb = - cumLogProb / static_cast(std::max(1, scores.size())); + const float avgLogProb = cumLogProb / static_cast(std::max(1, scores.size())); const std::string text = tokenizer_->decode(tokens, true); const float compressionRatio = this->calculateCompressionRatio(text); @@ -236,28 +215,24 @@ std::vector ASR::generate(std::span waveform, } } - return this->calculateWordLevelTimestamps(bestTokens, waveform, - bestAvgLogProb, bestTemperature, + return this->calculateWordLevelTimestamps(bestTokens, waveform, bestAvgLogProb, bestTemperature, bestCompressionRatio); } /** * Helper functions - generation wrapper, single-temperature inference */ -GenerationResult -ASR::generate(std::span waveform, const DecodingOptions &options, - float temperature, - std::optional> encoderOutput) const { +GenerationResult ASR::generate(std::span waveform, const DecodingOptions &options, + float temperature, + std::optional> encoderOutput) const { std::span encoderFeatures; if (encoderOutput.has_value()) { encoderFeatures = encoderOutput.value(); } else { executorch::aten::Tensor encoderFeaturesTensor = this->encode(waveform); - const float *encoderFeaturesData = - encoderFeaturesTensor.const_data_ptr(); + const float *encoderFeaturesData = encoderFeaturesTensor.const_data_ptr(); encoderFeatures = - std::span(encoderFeaturesData, - encoderFeaturesData + encoderFeaturesTensor.numel()); + std::span(encoderFeaturesData, encoderFeaturesData + encoderFeaturesTensor.numel()); } std::vector sequenceIds = this->createInitialSequence(options); @@ -282,8 +257,8 @@ ASR::generate(std::span waveform, const DecodingOptions &options, while (std::cmp_less(startPos, constants::kMaxDecodeLength)) { const size_t logitsInnerDim = logitsTensor.size(1); const size_t logitsDictSize = logitsTensor.size(2); - const float *logitsData = logitsTensor.const_data_ptr() + - (logitsInnerDim - 1) * logitsDictSize; + const float *logitsData = + logitsTensor.const_data_ptr() + (logitsInnerDim - 1) * logitsDictSize; // Needs to be float* without const for compatibility with utility functions std::span logits(const_cast(logitsData), const_cast(logitsData) + @@ -326,17 +301,15 @@ ASR::generate(std::span waveform, const DecodingOptions &options, ++startPos; } - return {.tokens = std::vector(cachedTokens.cbegin() + - initialSequenceLenght, + return {.tokens = std::vector(cachedTokens.cbegin() + initialSequenceLenght, cachedTokens.cend()), .scores = scores}; } -std::vector -ASR::calculateWordLevelTimestamps(std::span generatedTokens, - const std::span waveform, - float avgLogProb, float temperature, - float compressionRatio) const { +std::vector ASR::calculateWordLevelTimestamps(std::span generatedTokens, + const std::span waveform, + float avgLogProb, float temperature, + float compressionRatio) const { const size_t generatedTokensSize = generatedTokens.size(); if (generatedTokensSize < 2 || generatedTokens[generatedTokensSize - 1] != endOfTranscriptionToken_ || @@ -399,8 +372,7 @@ ASR::calculateWordLevelTimestamps(std::span generatedTokens, float scalingFactor = static_cast(waveform.size()) / - (constants::kSamplingRate * (end - timestampBeginToken_) * - constants::kTimePrecision); + (constants::kSamplingRate * (end - timestampBeginToken_) * constants::kTimePrecision); if (scalingFactor < 1.0f) { for (auto &seg : segments) { for (auto &w : seg.words) { @@ -413,9 +385,8 @@ ASR::calculateWordLevelTimestamps(std::span generatedTokens, return segments; } -std::vector -ASR::estimateWordLevelTimestampsLinear(std::span tokens, - uint64_t start, uint64_t end) const { +std::vector ASR::estimateWordLevelTimestampsLinear(std::span tokens, + uint64_t start, uint64_t end) const { const std::vector tokensVec(tokens.begin(), tokens.end()); const std::string segmentText = tokenizer_->decode(tokensVec, true); @@ -425,8 +396,7 @@ ASR::estimateWordLevelTimestampsLinear(std::span tokens, while (iss >> word) { // Detect special tokens such as [BLANK_AUDIO] by searching for square // bracket. - if (word.find('[') == std::string::npos && - word.find(']') == std::string::npos) { + if (word.find('[') == std::string::npos && word.find(']') == std::string::npos) { wordsStr.emplace_back(" "); wordsStr.back().append(word); } @@ -438,8 +408,7 @@ ASR::estimateWordLevelTimestampsLinear(std::span tokens, } const float duration = (end - start) * constants::kTimePrecision; const float timePerChar = duration / std::max(1, numChars); - const float startOffset = - (start - timestampBeginToken_) * constants::kTimePrecision; + const float startOffset = (start - timestampBeginToken_) * constants::kTimePrecision; std::vector wordObjs; wordObjs.reserve(wordsStr.size()); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.h index 475774f2e8..5a06973226 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/ASR.h @@ -36,9 +36,8 @@ class ASR : public models::BaseModel, public schema::ASR { * encode's input. * @param options Control variables for decoding process. */ - std::vector virtual transcribe( - std::span waveform, - const DecodingOptions &options) const override; + std::vector virtual transcribe(std::span waveform, + const DecodingOptions &options) const override; /** * Encodes the input audio waveform into mel spectrogram embeddings. @@ -48,8 +47,7 @@ class ASR : public models::BaseModel, public schema::ASR { * The output tensor shape: [1, 1500, 384] for Whisper * models. */ - executorch::aten::Tensor - encode(std::span waveform) const override; + executorch::aten::Tensor encode(std::span waveform) const override; /** * Decodes a sequence of tokens into logits given the encoded audio features. @@ -63,8 +61,7 @@ class ASR : public models::BaseModel, public schema::ASR { * @return A tensor representing the output logits for the next * token. */ - executorch::aten::Tensor decode(std::span tokens, - std::span encoderOutput, + executorch::aten::Tensor decode(std::span tokens, std::span encoderOutput, uint64_t startPos = 0) const override; // Standard ExecuTorch model methods for compatibility with the rest of the @@ -84,8 +81,7 @@ class ASR : public models::BaseModel, public schema::ASR { * such as whether * to add a language mark token or not. */ - std::vector - createInitialSequence(const DecodingOptions &options) const; + std::vector createInitialSequence(const DecodingOptions &options) const; /** * Generation wrapper - wrapps encoding & decoding with @@ -114,10 +110,9 @@ class ASR : public models::BaseModel, public schema::ASR { * @param encoderOutput An optional parameter. If provided, the encoding phase * is skipped and the provided value is used instead. */ - GenerationResult generate( - std::span waveform, const DecodingOptions &options, - float temperature, - std::optional> encoderOutput = std::nullopt) const; + GenerationResult + generate(std::span waveform, const DecodingOptions &options, float temperature, + std::optional> encoderOutput = std::nullopt) const; /** * Calculates word-level timestamps for a sequence of generated tokens. @@ -134,11 +129,10 @@ class ASR : public models::BaseModel, public schema::ASR { * @return A vector of transcribed segments with word-level * timing. */ - std::vector - calculateWordLevelTimestamps(std::span generatedTokens, - const std::span waveform, - float avgLogProb, float temperature, - float compressionRatio) const; + std::vector calculateWordLevelTimestamps(std::span generatedTokens, + const std::span waveform, + float avgLogProb, float temperature, + float compressionRatio) const; /** * Estimates word-level timestamps linearly within a token sequence. @@ -151,9 +145,8 @@ class ASR : public models::BaseModel, public schema::ASR { * @param end The timestamp token ID marking the end of the segment. * @return A vector of Word objects with estimated start and end times. */ - std::vector - estimateWordLevelTimestampsLinear(std::span tokens, - uint64_t start, uint64_t end) const; + std::vector estimateWordLevelTimestampsLinear(std::span tokens, + uint64_t start, uint64_t end) const; float calculateCompressionRatio(const std::string &text) const; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Constants.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Constants.h index 62a9f968f7..83d53d39c9 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Constants.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Constants.h @@ -30,8 +30,7 @@ inline constexpr size_t kNumFrames = 1500; inline constexpr float kTimePrecision = 0.02f; // Special characters serving as pause / end of sentence -inline const std::unordered_set kPunctations = {',', '.', '?', - '!', ':', ';'}; +inline const std::unordered_set kPunctations = {',', '.', '?', '!', ':', ';'}; inline const std::unordered_set kEosPunctations = {'.', '?', '!', ';'}; // Special token constants @@ -42,4 +41,4 @@ inline const std::string kBeginTimestamp = "<|0.00|>"; inline const std::string kBlankAudio = "[BLANK_AUDIO]"; } // namespace tokens -} // namespace rnexecutorch::models::speech_to_text::whisper::constants \ No newline at end of file +} // namespace rnexecutorch::models::speech_to_text::whisper::constants diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp index 963567d8b5..9d9ba69524 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.cpp @@ -10,8 +10,7 @@ namespace rnexecutorch::models::speech_to_text::whisper::stream { -OnlineASR::OnlineASR(const ASR *asr, const VoiceActivityDetection *vad) - : asr_(asr), vad_(vad) { +OnlineASR::OnlineASR(const ASR *asr, const VoiceActivityDetection *vad) : asr_(asr), vad_(vad) { audioBuffer_.reserve((constants::kChunkSize + 1) * constants::kSamplingRate); } @@ -44,10 +43,10 @@ void OnlineASR::insertAudioChunk(std::span audio) { } ProcessResult OnlineASR::process(const StreamingOptions &options) { - constexpr size_t kStreamSafeBufferMaxSamples = static_cast( - params::kStreamSafeBufferDuration * constants::kSamplingRate); - constexpr size_t kSafetyMarginSamples = static_cast( - params::kStreamSafetyThreshold * constants::kSamplingRate); + constexpr size_t kStreamSafeBufferMaxSamples = + static_cast(params::kStreamSafeBufferDuration * constants::kSamplingRate); + constexpr size_t kSafetyMarginSamples = + static_cast(params::kStreamSafetyThreshold * constants::kSamplingRate); std::vector audioCopy; @@ -63,15 +62,14 @@ ProcessResult OnlineASR::process(const StreamingOptions &options) { // Allowing VAD changes logic significantly - we no longer commit and clean // at max samples reached moments, but rather at the end of speech moments. if (options.useVAD && vad_) { - auto speechSegments = vad_->generate(audioCopy, options.vadDetectionMargin * - params::kVadGapFactor); + auto speechSegments = + vad_->generate(audioCopy, options.vadDetectionMargin * params::kVadGapFactor); if (speechSegments.empty()) { // Extra cleanup to speed-up future processing by removing silence. if (audioCopy.size() > params::kVadDeadSamplesRemovalSamples) { std::scoped_lock lock(streamingMutex); - size_t cut = std::min(params::kVadDeadSamplesRemovalSamples - - kSafetyMarginSamples, + size_t cut = std::min(params::kVadDeadSamplesRemovalSamples - kSafetyMarginSamples, audioBuffer_.size()); audioBuffer_.erase(audioBuffer_.begin(), audioBuffer_.begin() + cut); } @@ -80,17 +78,14 @@ ProcessResult OnlineASR::process(const StreamingOptions &options) { } const auto &lastSegment = speechSegments.back(); - size_t marginSamples = - options.vadDetectionMargin * constants::kSamplesPerMilisecond; + size_t marginSamples = options.vadDetectionMargin * constants::kSamplesPerMilisecond; if (audioCopy.size() - lastSegment.end <= marginSamples) { // Speech is ongoing. Keep last 1s context and trim around current // segment. size_t startWithMargin = - std::max(lastSegment.start, constants::kSamplingRate) - - constants::kSamplingRate; - input = std::span(audioCopy.begin() + startWithMargin, - audioCopy.begin() + lastSegment.end); + std::max(lastSegment.start, constants::kSamplingRate) - constants::kSamplingRate; + input = std::span(audioCopy.begin() + startWithMargin, audioCopy.begin() + lastSegment.end); } else { // Speech ended beyond margin. Commit existing transcript and clear // buffer. @@ -100,14 +95,12 @@ ProcessResult OnlineASR::process(const StreamingOptions &options) { memory_.eos.clear(); audioBuffer_.erase(audioBuffer_.begin(), - audioBuffer_.begin() + - std::min(lastSegment.end, audioBuffer_.size())); + audioBuffer_.begin() + std::min(lastSegment.end, audioBuffer_.size())); return {.committed = std::move(committed), .nonCommitted = {}}; } } else { input = std::span(audioCopy.begin(), - audioCopy.begin() + - std::min(constants::kMaxSamples, audioCopy.size())); + audioCopy.begin() + std::min(constants::kMaxSamples, audioCopy.size())); } std::vector transcriptions = asr_->transcribe(input, options); @@ -127,8 +120,7 @@ ProcessResult OnlineASR::process(const StreamingOptions &options) { // due to model correcting it's output. for (auto it = memory_.eos.begin(); it != memory_.eos.end(); it++) { if (it->position >= words.size() || !utils::isEos(words[it->position]) || - (it->position > 0 && - it->preceeding != words[it->position - 1].content)) { + (it->position > 0 && it->preceeding != words[it->position - 1].content)) { memory_.eos.erase(it, memory_.eos.end()); break; } @@ -143,8 +135,7 @@ ProcessResult OnlineASR::process(const StreamingOptions &options) { // Because of step 1, we know that if the last EOS exist in eos_, // then it must be the last entry. if (memory_.eos.empty() || memory_.eos.back().position != lastEosIndex) { - std::string preceeding = - lastEosIndex > 0 ? words[lastEosIndex - 1].content : ""; + std::string preceeding = lastEosIndex > 0 ? words[lastEosIndex - 1].content : ""; memory_.eos.emplace_back(lastEosIndex, preceeding, lastEosIt->end); } } @@ -154,8 +145,7 @@ ProcessResult OnlineASR::process(const StreamingOptions &options) { // Step 3: collect all the words which could possible get committed // in-between iterations. if (!memory_.toCommit.empty()) { - committed.insert(committed.end(), - std::make_move_iterator(memory_.toCommit.begin()), + committed.insert(committed.end(), std::make_move_iterator(memory_.toCommit.begin()), std::make_move_iterator(memory_.toCommit.end())); memory_.toCommit.clear(); } @@ -168,8 +158,7 @@ ProcessResult OnlineASR::process(const StreamingOptions &options) { if (bufferSize > kStreamSafeBufferMaxSamples) { auto newCommitted = commitAndClean(words); - committed.insert(committed.end(), - std::make_move_iterator(newCommitted.begin()), + committed.insert(committed.end(), std::make_move_iterator(newCommitted.begin()), std::make_move_iterator(newCommitted.end())); } @@ -189,8 +178,7 @@ std::vector OnlineASR::finish(const StreamingOptions &options) { // Last-tick committed delta + whatever never made it past the commit // threshold. std::vector residual{std::move(result.committed)}; - residual.insert(residual.end(), - std::make_move_iterator(result.nonCommitted.begin()), + residual.insert(residual.end(), std::make_move_iterator(result.nonCommitted.begin()), std::make_move_iterator(result.nonCommitted.end())); reset(); @@ -213,8 +201,8 @@ std::vector OnlineASR::commitAndClean(std::vector &transcript) { constexpr float kMidpointAnchorTime = params::kStreamMaxDuration / 2.0F; constexpr size_t kMidpointAnchorSamples = static_cast(kMidpointAnchorTime * constants::kSamplingRate); - constexpr size_t kSafetyMarginSamples = static_cast( - params::kStreamSafetyThreshold * constants::kSamplingRate); + constexpr size_t kSafetyMarginSamples = + static_cast(params::kStreamSafetyThreshold * constants::kSamplingRate); constexpr float kMaxSafeEosTime = params::kStreamSafeBufferDuration - params::kStreamSafetyThreshold; constexpr float kMinDurationToCalculateDensity = 0.1F; @@ -239,12 +227,10 @@ std::vector OnlineASR::commitAndClean(std::vector &transcript) { else if (memory_.eos.size() == 1) { const float eosTimestamp = memory_.eos[0].tmstpend; - const float upperHalfDuration = - std::max(0.0F, eosTimestamp - kMidpointAnchorTime); - const float wordsPerSecond = - upperHalfDuration > kMinDurationToCalculateDensity - ? static_cast(transcript.size()) / upperHalfDuration - : 0.0F; + const float upperHalfDuration = std::max(0.0F, eosTimestamp - kMidpointAnchorTime); + const float wordsPerSecond = upperHalfDuration > kMinDurationToCalculateDensity + ? static_cast(transcript.size()) / upperHalfDuration + : 0.0F; // The EOS sits early enough that cutting up to the safety margin won't // touch the ongoing (post-EOS) speech. @@ -254,22 +240,18 @@ std::vector OnlineASR::commitAndClean(std::vector &transcript) { // EOS lies past the midpoint, but a low word density implies the spoken // audio is concentrated in the upper half. Drop the lower half and // shift the EOS accordingly. - audioBuffer_.erase(audioBuffer_.begin(), - audioBuffer_.begin() + kMidpointAnchorSamples); + audioBuffer_.erase(audioBuffer_.begin(), audioBuffer_.begin() + kMidpointAnchorSamples); memory_.eos[0].tmstpend -= kMidpointAnchorTime; } else { // Cut everything up to and including the sentence — either by the // safety margin (when EOS is early) or (more aggresively) right at the // EOS boundary — and commit its words. - const size_t cut = - eosSafe - ? bufferSize - kSafetyMarginSamples - : static_cast(eosTimestamp * constants::kSamplingRate); + const size_t cut = eosSafe ? bufferSize - kSafetyMarginSamples + : static_cast(eosTimestamp * constants::kSamplingRate); audioBuffer_.erase(audioBuffer_.begin(), audioBuffer_.begin() + cut); - committed.insert(committed.end(), - std::make_move_iterator(transcript.begin()), + committed.insert(committed.end(), std::make_move_iterator(transcript.begin()), std::make_move_iterator(transcript.end())); transcript.clear(); @@ -282,17 +264,14 @@ std::vector OnlineASR::commitAndClean(std::vector &transcript) { else { const auto &secondTolastEntry = memory_.eos[memory_.eos.size() - 2]; - const size_t cut = static_cast(secondTolastEntry.tmstpend * - constants::kSamplingRate); + const size_t cut = static_cast(secondTolastEntry.tmstpend * constants::kSamplingRate); const size_t lastCommittedPos = secondTolastEntry.position; audioBuffer_.erase(audioBuffer_.begin(), audioBuffer_.begin() + cut); - committed.insert( - committed.end(), std::make_move_iterator(transcript.begin()), - std::make_move_iterator(transcript.begin() + lastCommittedPos + 1)); - transcript.erase(transcript.begin(), - transcript.begin() + lastCommittedPos + 1); + committed.insert(committed.end(), std::make_move_iterator(transcript.begin()), + std::make_move_iterator(transcript.begin() + lastCommittedPos + 1)); + transcript.erase(transcript.begin(), transcript.begin() + lastCommittedPos + 1); // Retain only the most recent EOS entry, shifting both its timestamp // and its position to match the new (truncated) transcript origin. diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h index f13fa9ccff..3e334a0478 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/OnlineASR.h @@ -74,18 +74,16 @@ class OnlineASR : public schema::OnlineASR { struct Memory { // State management helper. struct EOSEntry { - size_t position; // An absolute position (index) in the transcription - // (word sequence). + size_t position; // An absolute position (index) in the transcription + // (word sequence). std::string preceeding; // A preceeding word in the transcription float tmstpend; // Ending timestamp of the sentence. }; - std::vector - transcript; // The most recent transcription result (uncommitted only!). - std::vector - eos; // End of sentence points from the most recent transcription. - std::vector toCommit; // Words to be committed in the next iteration - // (next process() call). + std::vector transcript; // The most recent transcription result (uncommitted only!). + std::vector eos; // End of sentence points from the most recent transcription. + std::vector toCommit; // Words to be committed in the next iteration + // (next process() call). } memory_; }; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h index f6c6e491b0..556f34199c 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Params.h @@ -17,8 +17,7 @@ namespace rnexecutorch::models::speech_to_text::whisper::params { * Maximum duration of audio that the streaming buffer keeps before forcing * a cleanup. Aligned with Whisper's maximum supported input length. */ -constexpr inline float kStreamMaxDuration = - static_cast(constants::kChunkSize); +constexpr inline float kStreamMaxDuration = static_cast(constants::kChunkSize); /** * The minimum amount of recent audio always kept in the buffer when a blind @@ -75,8 +74,7 @@ constexpr inline size_t kVadDeadSamplesRemovalSamples = // two ever invert, that subtraction wraps and the subsequent erase reads past // the buffer. Catch the regression at compile time. static_assert(kVadDeadSamplesRemovalSamples > - static_cast(kStreamSafetyThreshold * - constants::kSamplingRate), + static_cast(kStreamSafetyThreshold * constants::kSamplingRate), "kVadDeadSamplesRemovalSamples must exceed the safety margin"); } // namespace rnexecutorch::models::speech_to_text::whisper::params diff --git a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Utils.h b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Utils.h index ae461c27cf..9e28c1bd07 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Utils.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/speech_to_text/whisper/Utils.h @@ -15,8 +15,7 @@ namespace rnexecutorch::models::speech_to_text::whisper::utils { * @param word The word to check. */ inline bool isEos(const Word &word) { - return word.content.size() == 1 && - constants::kEosPunctations.contains(word.content[0]); + return word.content.size() == 1 && constants::kEosPunctations.contains(word.content[0]); } -} // namespace rnexecutorch::models::speech_to_text::whisper::utils \ No newline at end of file +} // namespace rnexecutorch::models::speech_to_text::whisper::utils diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp index 184d2db10d..72e032a370 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.cpp @@ -28,8 +28,7 @@ StyleTransfer::StyleTransfer(const std::string &modelSource, "Unexpected model input size, expected at least 2 dimensions " "but got: %zu.", modelInputShape_.size()); - throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, - errorMessage); + throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, errorMessage); } } @@ -38,14 +37,13 @@ cv::Mat StyleTransfer::runInference(cv::Mat image, cv::Size outputSize) { cv::Mat preprocessed = preprocess(image); - auto inputTensor = - image_processing::getTensorFromMatrix(modelInputShape_, preprocessed); + auto inputTensor = image_processing::getTensorFromMatrix(modelInputShape_, preprocessed); auto forwardResult = BaseModel::forward(inputTensor); CHECK_OK_OR_THROW_FORWARD_ERROR(forwardResult); - cv::Mat mat = image_processing::getMatrixFromTensor( - modelInputSize(), forwardResult->at(0).toTensor()); + cv::Mat mat = + image_processing::getMatrixFromTensor(modelInputSize(), forwardResult->at(0).toTensor()); if (mat.size() != outputSize) { cv::resize(mat, mat, outputSize); } @@ -62,8 +60,7 @@ PixelDataResult toPixelDataResult(const cv::Mat &bgrMat) { return PixelDataResult{pixelBuffer, size.width, size.height, rgba.channels()}; } -StyleTransferResult StyleTransfer::generateFromString(std::string imageSource, - bool saveToFile) { +StyleTransferResult StyleTransfer::generateFromString(std::string imageSource, bool saveToFile) { cv::Mat imageBGR = image_processing::readImage(imageSource); cv::Size originalSize = imageBGR.size(); @@ -87,8 +84,7 @@ PixelDataResult StyleTransfer::generateFromFrame(jsi::Runtime &runtime, return toPixelDataResult(oriented); } -StyleTransferResult StyleTransfer::generateFromPixels(JSTensorViewIn pixelData, - bool saveToFile) { +StyleTransferResult StyleTransfer::generateFromPixels(JSTensorViewIn pixelData, bool saveToFile) { cv::Mat image = extractFromPixels(pixelData); cv::Mat result = runInference(image, image.size()); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h index c15095bf5b..6459e27dba 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/style_transfer/StyleTransfer.h @@ -19,8 +19,7 @@ using namespace facebook; class StyleTransfer : public VisionModel { public: - StyleTransfer(const std::string &modelSource, - std::shared_ptr callInvoker); + StyleTransfer(const std::string &modelSource, std::shared_ptr callInvoker); [[nodiscard("Registered non-void function")]] StyleTransferResult generateFromString(std::string imageSource, bool saveToFile); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Decoder.cpp b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Decoder.cpp index e3e37521ee..1e44970573 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Decoder.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Decoder.cpp @@ -10,21 +10,17 @@ namespace rnexecutorch::models::text_to_image { using namespace executorch::extension; -Decoder::Decoder(const std::string &modelSource, - std::shared_ptr callInvoker) +Decoder::Decoder(const std::string &modelSource, std::shared_ptr callInvoker) : BaseModel(modelSource, callInvoker) {} std::vector Decoder::generate(std::vector &input) const { - std::vector inputShape = {1, numChannels, latentImageSize, - latentImageSize}; - auto inputTensor = - make_tensor_ptr(inputShape, input.data(), ScalarType::Float); + std::vector inputShape = {1, numChannels, latentImageSize, latentImageSize}; + auto inputTensor = make_tensor_ptr(inputShape, input.data(), ScalarType::Float); auto forwardResult = BaseModel::forward(inputTensor); if (!forwardResult.ok()) { - throw RnExecutorchError( - forwardResult.error(), - "Function forward in decoder failed with error code: "); + throw RnExecutorchError(forwardResult.error(), + "Function forward in decoder failed with error code: "); } auto forwardResultTensor = forwardResult->at(0).toTensor(); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Decoder.h b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Decoder.h index c0b35c102a..176536df34 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Decoder.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Decoder.h @@ -12,8 +12,7 @@ namespace rnexecutorch::models::text_to_image { class Decoder final : public BaseModel { public: - explicit Decoder(const std::string &modelSource, - std::shared_ptr callInvoker); + explicit Decoder(const std::string &modelSource, std::shared_ptr callInvoker); std::vector generate(std::vector &input) const; int32_t latentImageSize; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.cpp b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.cpp index 68a9a9fef4..f09ea8cd6f 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.cpp @@ -8,12 +8,10 @@ namespace rnexecutorch::models::text_to_image { -Encoder::Encoder(const std::string &tokenizerSource, - const std::string &encoderSource, +Encoder::Encoder(const std::string &tokenizerSource, const std::string &encoderSource, std::shared_ptr callInvoker) - : callInvoker(callInvoker), - encoder(std::make_unique( - encoderSource, tokenizerSource, callInvoker)) {} + : callInvoker(callInvoker), encoder(std::make_unique( + encoderSource, tokenizerSource, callInvoker)) {} std::vector Encoder::generate(std::string input) { std::shared_ptr embeddingsText = encoder->generate(input); @@ -23,8 +21,7 @@ std::vector Encoder::generate(std::string input) { assert(embeddingsText->size() == embeddingsUncond->size()); size_t embeddingsSize = embeddingsText->size() / sizeof(float); auto *embeddingsTextPtr = reinterpret_cast(embeddingsText->data()); - auto *embeddingsUncondPtr = - reinterpret_cast(embeddingsUncond->data()); + auto *embeddingsUncondPtr = reinterpret_cast(embeddingsUncond->data()); std::vector embeddingsConcat; embeddingsConcat.reserve(embeddingsSize * 2); @@ -35,9 +32,7 @@ std::vector Encoder::generate(std::string input) { return embeddingsConcat; } -size_t Encoder::getMemoryLowerBound() const noexcept { - return encoder->getMemoryLowerBound(); -} +size_t Encoder::getMemoryLowerBound() const noexcept { return encoder->getMemoryLowerBound(); } void Encoder::unload() noexcept { encoder->unload(); } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.h b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.h index b444f30ab9..04db8e2384 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.h @@ -17,8 +17,7 @@ using namespace facebook; class Encoder final { public: - explicit Encoder(const std::string &tokenizerSource, - const std::string &encoderSource, + explicit Encoder(const std::string &tokenizerSource, const std::string &encoderSource, std::shared_ptr callInvoker); std::vector generate(std::string input); size_t getMemoryLowerBound() const noexcept; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Scheduler.cpp b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Scheduler.cpp index 61640f7f64..0722e92f04 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Scheduler.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Scheduler.cpp @@ -11,8 +11,7 @@ namespace rnexecutorch::models::text_to_image { using namespace facebook; -Scheduler::Scheduler(float betaStart, float betaEnd, int32_t numTrainTimesteps, - int32_t stepsOffset, +Scheduler::Scheduler(float betaStart, float betaEnd, int32_t numTrainTimesteps, int32_t stepsOffset, std::shared_ptr callInvoker) : numTrainTimesteps(numTrainTimesteps), stepsOffset(stepsOffset) { const float start = std::sqrt(betaStart); @@ -56,11 +55,9 @@ void Scheduler::setTimesteps(size_t numInferenceSteps) { timesteps.clear(); timesteps.reserve(numInferenceSteps + 1); - float numStepsRatio = - static_cast(numTrainTimesteps) / numInferenceSteps; + float numStepsRatio = static_cast(numTrainTimesteps) / numInferenceSteps; for (size_t i = 0; i < numInferenceSteps; i++) { - const auto timestep = - static_cast(std::round(i * numStepsRatio)) + stepsOffset; + const auto timestep = static_cast(std::round(i * numStepsRatio)) + stepsOffset; timesteps.push_back(timestep); } // Duplicate the timestep to provide enough points for the solver @@ -69,18 +66,15 @@ void Scheduler::setTimesteps(size_t numInferenceSteps) { } std::vector Scheduler::step(const std::vector &sample, - const std::vector &noise, - int32_t timestep) { + const std::vector &noise, int32_t timestep) { if (numInferenceSteps == 0) { - throw RnExecutorchError( - RnExecutorchErrorCode::InvalidConfig, - "Number of inference steps is not set. Call `set_timesteps` first."); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidConfig, + "Number of inference steps is not set. Call `set_timesteps` first."); } size_t noiseSize = noise.size(); std::vector etsOutput(noiseSize); - float numStepsRatio = - static_cast(numTrainTimesteps) / numInferenceSteps; + float numStepsRatio = static_cast(numTrainTimesteps) / numInferenceSteps; float timestepPrev = timestep - numStepsRatio; if (ets.empty()) { @@ -95,8 +89,8 @@ std::vector Scheduler::step(const std::vector &sample, for (size_t i = 0; i < noiseSize; i++) { etsOutput[i] = (noise[i] + ets[0][i]) / 2; } - auto prevSample = getPrevSample(std::move(tempFirstSample), etsOutput, - timestep + numStepsRatio, timestep); + auto prevSample = + getPrevSample(std::move(tempFirstSample), etsOutput, timestep + numStepsRatio, timestep); tempFirstSample.clear(); return prevSample; } @@ -116,36 +110,30 @@ std::vector Scheduler::step(const std::vector &sample, } else { ets.assign(ets.end() - 4, ets.end()); for (size_t i = 0; i < noiseSize; i++) { - etsOutput[i] = - (ets[3][i] * 55 - ets[2][i] * 59 + ets[1][i] * 37 - ets[0][i] * 9) / - 24; + etsOutput[i] = (ets[3][i] * 55 - ets[2][i] * 59 + ets[1][i] * 37 - ets[0][i] * 9) / 24; } } return getPrevSample(sample, etsOutput, timestep, timestepPrev); } std::vector Scheduler::getPrevSample(const std::vector &sample, - const std::vector &noise, - int32_t timestep, + const std::vector &noise, int32_t timestep, int32_t timestepPrev) const { const float alpha = alphasCumprod[timestep]; - const float alphaPrev = - timestepPrev >= 0 ? alphasCumprod[timestepPrev] : finalAlphaCumprod; + const float alphaPrev = timestepPrev >= 0 ? alphasCumprod[timestepPrev] : finalAlphaCumprod; const float beta = 1 - alpha; const float betaPrev = 1 - alphaPrev; size_t noiseSize = noise.size(); const float noiseCoeff = - (alphaPrev - alpha) / - (alpha * std::sqrt(betaPrev) + std::sqrt(alpha * beta * alphaPrev)); + (alphaPrev - alpha) / (alpha * std::sqrt(betaPrev) + std::sqrt(alpha * beta * alphaPrev)); const float sampleCoeff = std::sqrt(alphaPrev / alpha); std::vector samplePrev; samplePrev.reserve(noiseSize); for (size_t i = 0; i < noiseSize; i++) { const float noiseTerm = - (noise[i] * std::sqrt(alpha) + sample[i] * std::sqrt(beta)) * - noiseCoeff; + (noise[i] * std::sqrt(alpha) + sample[i] * std::sqrt(beta)) * noiseCoeff; samplePrev.push_back(sample[i] * sampleCoeff - noiseTerm); } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Scheduler.h b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Scheduler.h index 66fe422c3b..f40334e34a 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Scheduler.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Scheduler.h @@ -11,12 +11,11 @@ using namespace facebook; class Scheduler final { public: - explicit Scheduler(float betaStart, float betaEnd, int32_t numTrainTimesteps, - int32_t stepsOfset, + explicit Scheduler(float betaStart, float betaEnd, int32_t numTrainTimesteps, int32_t stepsOfset, std::shared_ptr callInvoker); void setTimesteps(size_t numInferenceSteps); - std::vector step(const std::vector &sample, - const std::vector &noise, int32_t timestep); + std::vector step(const std::vector &sample, const std::vector &noise, + int32_t timestep); std::vector timesteps; @@ -34,8 +33,7 @@ class Scheduler final { size_t numInferenceSteps{0}; std::vector getPrevSample(const std::vector &sample, - const std::vector &noise, - int32_t timestep, + const std::vector &noise, int32_t timestep, int32_t prevTimestep) const; }; } // namespace rnexecutorch::models::text_to_image diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.cpp b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.cpp index d96f76c8ec..a1192000e8 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.cpp @@ -18,20 +18,16 @@ namespace rnexecutorch::models::text_to_image { using namespace executorch::extension; -TextToImage::TextToImage(const std::string &tokenizerSource, - const std::string &encoderSource, - const std::string &unetSource, - const std::string &decoderSource, +TextToImage::TextToImage(const std::string &tokenizerSource, const std::string &encoderSource, + const std::string &unetSource, const std::string &decoderSource, float schedulerBetaStart, float schedulerBetaEnd, - int32_t schedulerNumTrainTimesteps, - int32_t schedulerStepsOffset, + int32_t schedulerNumTrainTimesteps, int32_t schedulerStepsOffset, std::shared_ptr callInvoker) : callInvoker(callInvoker), - scheduler(std::make_unique( - schedulerBetaStart, schedulerBetaEnd, schedulerNumTrainTimesteps, - schedulerStepsOffset, callInvoker)), - encoder(std::make_unique(tokenizerSource, encoderSource, - callInvoker)), + scheduler(std::make_unique(schedulerBetaStart, schedulerBetaEnd, + schedulerNumTrainTimesteps, schedulerStepsOffset, + callInvoker)), + encoder(std::make_unique(tokenizerSource, encoderSource, callInvoker)), unet(std::make_unique(unetSource, callInvoker)), decoder(std::make_unique(decoderSource, callInvoker)) {} @@ -56,17 +52,15 @@ void TextToImage::setSeed(int32_t &seed) { seed = rd(); } -std::string TextToImage::generate(std::string input, int32_t imageSize, - size_t numInferenceSteps, int32_t seed, - std::shared_ptr callback) { +std::string TextToImage::generate(std::string input, int32_t imageSize, size_t numInferenceSteps, + int32_t seed, std::shared_ptr callback) { std::scoped_lock lock(inference_mutex_); setImageSize(imageSize); setSeed(seed); std::vector embeddings = encoder->generate(input); std::vector embeddingsShape = {2, 77, 768}; - auto embeddingsTensor = - make_tensor_ptr(embeddingsShape, embeddings.data(), ScalarType::Float); + auto embeddingsTensor = make_tensor_ptr(embeddingsShape, embeddings.data(), ScalarType::Float); int32_t latentsSize = numChannels * latentImageSize * latentImageSize; std::vector latents(latentsSize); @@ -87,18 +81,15 @@ std::string TextToImage::generate(std::string input, int32_t imageSize, for (size_t t = 0; t < numInferenceSteps + 1 && !interrupted; t++) { log(LOG_LEVEL::Debug, "Step:", t, "/", numInferenceSteps); - std::vector noisePred = - unet->generate(latents, timesteps[t], embeddingsTensor); + std::vector noisePred = unet->generate(latents, timesteps[t], embeddingsTensor); size_t noiseSize = noisePred.size() / 2; std::span noisePredSpan{noisePred}; std::span noiseUncond = noisePredSpan.subspan(0, noiseSize); - std::span noiseText = - noisePredSpan.subspan(noiseSize, noiseSize); + std::span noiseText = noisePredSpan.subspan(noiseSize, noiseSize); std::vector noise(noiseSize); for (size_t i = 0; i < noiseSize; i++) { - noise[i] = - noiseUncond[i] * (1 - guidanceScale) + noiseText[i] * guidanceScale; + noise[i] = noiseUncond[i] * (1 - guidanceScale) + noiseText[i] * guidanceScale; } latents = scheduler->step(latents, noise, timesteps[t]); @@ -125,9 +116,9 @@ std::string TextToImage::postprocess(const std::vector &output) const { auto *row = bgr.ptr(y); for (int32_t x = 0; x < imageSize; ++x) { const int32_t idx = (y * imageSize + x) * 3; - row[x] = cv::Vec3b(static_cast(output[idx + 2]), - static_cast(output[idx + 1]), - static_cast(output[idx + 0])); + row[x] = + cv::Vec3b(static_cast(output[idx + 2]), static_cast(output[idx + 1]), + static_cast(output[idx + 0])); } } return image_processing::saveToTempFile(bgr); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.h b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.h index 5eaae65f08..4b8325fc10 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.h @@ -21,16 +21,12 @@ using namespace facebook; class TextToImage final { public: - explicit TextToImage(const std::string &tokenizerSource, - const std::string &encoderSource, - const std::string &unetSource, - const std::string &decoderSource, + explicit TextToImage(const std::string &tokenizerSource, const std::string &encoderSource, + const std::string &unetSource, const std::string &decoderSource, float schedulerBetaStart, float schedulerBetaEnd, - int32_t schedulerNumTrainTimesteps, - int32_t schedulerStepsOffset, + int32_t schedulerNumTrainTimesteps, int32_t schedulerStepsOffset, std::shared_ptr callInvoker); - std::string generate(std::string input, int32_t imageSize, - size_t numInferenceSteps, int32_t seed, + std::string generate(std::string input, int32_t imageSize, size_t numInferenceSteps, int32_t seed, std::shared_ptr callback); void interrupt() noexcept; size_t getMemoryLowerBound() const noexcept; @@ -58,7 +54,7 @@ class TextToImage final { }; } // namespace models::text_to_image -REGISTER_CONSTRUCTOR(models::text_to_image::TextToImage, std::string, - std::string, std::string, std::string, float, float, - int32_t, int32_t, std::shared_ptr); +REGISTER_CONSTRUCTOR(models::text_to_image::TextToImage, std::string, std::string, std::string, + std::string, float, float, int32_t, int32_t, + std::shared_ptr); } // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/UNet.cpp b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/UNet.cpp index d2d58badb8..47bbf95508 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/UNet.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/UNet.cpp @@ -6,8 +6,7 @@ namespace rnexecutorch::models::text_to_image { using namespace executorch::extension; -UNet::UNet(const std::string &modelSource, - std::shared_ptr callInvoker) +UNet::UNet(const std::string &modelSource, std::shared_ptr callInvoker) : BaseModel(modelSource, callInvoker) {} std::vector UNet::generate(std::vector &latents, int32_t timestep, @@ -17,16 +16,12 @@ std::vector UNet::generate(std::vector &latents, int32_t timestep, latentsConcat.insert(latentsConcat.end(), latents.begin(), latents.end()); latentsConcat.insert(latentsConcat.end(), latents.begin(), latents.end()); - std::vector latentsShape = {2, numChannels, latentImageSize, - latentImageSize}; + std::vector latentsShape = {2, numChannels, latentImageSize, latentImageSize}; - auto timestepTensor = - make_tensor_ptr({static_cast(timestep)}); - auto latentsTensor = - make_tensor_ptr(latentsShape, latentsConcat.data(), ScalarType::Float); + auto timestepTensor = make_tensor_ptr({static_cast(timestep)}); + auto latentsTensor = make_tensor_ptr(latentsShape, latentsConcat.data(), ScalarType::Float); - auto forwardResult = - BaseModel::forward({latentsTensor, timestepTensor, embeddingsTensor}); + auto forwardResult = BaseModel::forward({latentsTensor, timestepTensor, embeddingsTensor}); CHECK_OK_OR_THROW_FORWARD_ERROR(forwardResult); auto forwardResultTensor = forwardResult->at(0).toTensor(); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/UNet.h b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/UNet.h index 0c6dd057c6..f870275a84 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/UNet.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/UNet.h @@ -15,8 +15,7 @@ using namespace executorch::extension; class UNet final : public BaseModel { public: - explicit UNet(const std::string &modelSource, - std::shared_ptr callInvoker); + explicit UNet(const std::string &modelSource, std::shared_ptr callInvoker); std::vector generate(std::vector &latents, int32_t timestep, TensorPtr &embeddingsTensor) const; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/TextToSpeech.h b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/TextToSpeech.h index 2fb17c95f5..fe426fe5f6 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/TextToSpeech.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/TextToSpeech.h @@ -1,3 +1,3 @@ #pragma once -#include "kokoro/Kokoro.h" \ No newline at end of file +#include "kokoro/Kokoro.h" diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Constants.h b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Constants.h index 21f0e11501..d23b13bcc0 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Constants.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Constants.h @@ -14,12 +14,10 @@ inline constexpr size_t kMinInputTokens = 8; // Models do not accept less amount of tokens (including padding) inline constexpr size_t kMaxInputTokens = 128; // Models do not accept more tokens (including padding) -inline constexpr size_t kMinDurationTicks = - 16; // Corresponds to DurationPredictor output and one of Synthesizer's - // input shapes -inline constexpr size_t kMaxDurationTicks = - 296; // Corresponds to DurationPredictor output and one of Synthesizer's - // input shapes +inline constexpr size_t kMinDurationTicks = 16; // Corresponds to DurationPredictor output and one + // of Synthesizer's input shapes +inline constexpr size_t kMaxDurationTicks = 296; // Corresponds to DurationPredictor output and one + // of Synthesizer's input shapes inline constexpr float kMinValidSpeed = 0.1F; inline constexpr float kMaxValidSpeed = 3.0F; @@ -32,8 +30,7 @@ inline constexpr int32_t kVoiceRefHalfSize = kVoiceRefSize / 2; // This corresponds to a number of elements in resulting audio vector per each // duration point. inline constexpr int32_t kTicksPerDuration = 600; -inline constexpr int32_t kSamplingRate = - 24000; // Corresponds to Kokoro's model audio frequency +inline constexpr int32_t kSamplingRate = 24000; // Corresponds to Kokoro's model audio frequency inline constexpr int32_t kSamplesPerMilisecond = kSamplingRate / 1000; // Special text characters @@ -57,32 +54,28 @@ inline const std::unordered_set kPauseCharacters = { // Phoneme to token mappings inline constexpr int32_t kVocabSize = 178; inline const std::unordered_map kVocab = { - {U';', 1}, {U':', 2}, {U',', 3}, {U'.', 4}, {U'!', 5}, - {U'?', 6}, {U'—', 9}, {U'…', 10}, {U'"', 11}, {U'(', 12}, - {U')', 13}, {U'“', 14}, {U'”', 15}, {U' ', 16}, {U'\u0303', 17}, - {U'ʣ', 18}, {U'ʥ', 19}, {U'ʦ', 20}, {U'ʨ', 21}, {U'ᵝ', 22}, - {U'\uAB67', 23}, {U'A', 24}, {U'I', 25}, {U'O', 31}, {U'Q', 33}, - {U'S', 35}, {U'T', 36}, {U'W', 39}, {U'Y', 41}, {U'ᵊ', 42}, - {U'a', 43}, {U'b', 44}, {U'c', 45}, {U'd', 46}, {U'e', 47}, - {U'f', 48}, {U'h', 50}, {U'i', 51}, {U'j', 52}, {U'k', 53}, - {U'l', 54}, {U'm', 55}, {U'n', 56}, {U'o', 57}, {U'p', 58}, - {U'q', 59}, {U'r', 60}, {U's', 61}, {U't', 62}, {U'u', 63}, - {U'v', 64}, {U'w', 65}, {U'x', 66}, {U'y', 67}, {U'z', 68}, - {U'ɑ', 69}, {U'ɐ', 70}, {U'ɒ', 71}, {U'æ', 72}, {U'β', 75}, - {U'ɔ', 76}, {U'ɕ', 77}, {U'ç', 78}, {U'ɖ', 80}, {U'ð', 81}, - {U'ʤ', 82}, {U'ə', 83}, {U'ɚ', 85}, {U'ɛ', 86}, {U'ɜ', 87}, - {U'ɟ', 90}, {U'ɡ', 92}, {U'ɥ', 99}, {U'ɨ', 101}, {U'ɪ', 102}, - {U'ʝ', 103}, {U'ɯ', 110}, {U'ɰ', 111}, {U'ŋ', 112}, {U'ɳ', 113}, - {U'ɲ', 114}, {U'ɴ', 115}, {U'ø', 116}, {U'ɸ', 118}, {U'θ', 119}, - {U'œ', 120}, {U'ɹ', 123}, {U'ɾ', 125}, {U'ɻ', 126}, {U'ʁ', 128}, - {U'ɽ', 129}, {U'ʂ', 130}, {U'ʃ', 131}, {U'ʈ', 132}, {U'ʧ', 133}, - {U'ʊ', 135}, {U'ʋ', 136}, {U'ʌ', 138}, {U'ɣ', 139}, {U'ɤ', 140}, - {U'χ', 142}, {U'ʎ', 143}, {U'ʒ', 147}, {U'ʔ', 148}, {U'ˈ', 156}, - {U'ˌ', 157}, {U'ː', 158}, {U'ʰ', 162}, {U'ʲ', 164}, {U'↓', 169}, - {U'→', 171}, {U'↗', 172}, {U'↘', 173}, {U'ᵻ', 177}}; + {U';', 1}, {U':', 2}, {U',', 3}, {U'.', 4}, {U'!', 5}, {U'?', 6}, + {U'—', 9}, {U'…', 10}, {U'"', 11}, {U'(', 12}, {U')', 13}, {U'“', 14}, + {U'”', 15}, {U' ', 16}, {U'\u0303', 17}, {U'ʣ', 18}, {U'ʥ', 19}, {U'ʦ', 20}, + {U'ʨ', 21}, {U'ᵝ', 22}, {U'\uAB67', 23}, {U'A', 24}, {U'I', 25}, {U'O', 31}, + {U'Q', 33}, {U'S', 35}, {U'T', 36}, {U'W', 39}, {U'Y', 41}, {U'ᵊ', 42}, + {U'a', 43}, {U'b', 44}, {U'c', 45}, {U'd', 46}, {U'e', 47}, {U'f', 48}, + {U'h', 50}, {U'i', 51}, {U'j', 52}, {U'k', 53}, {U'l', 54}, {U'm', 55}, + {U'n', 56}, {U'o', 57}, {U'p', 58}, {U'q', 59}, {U'r', 60}, {U's', 61}, + {U't', 62}, {U'u', 63}, {U'v', 64}, {U'w', 65}, {U'x', 66}, {U'y', 67}, + {U'z', 68}, {U'ɑ', 69}, {U'ɐ', 70}, {U'ɒ', 71}, {U'æ', 72}, {U'β', 75}, + {U'ɔ', 76}, {U'ɕ', 77}, {U'ç', 78}, {U'ɖ', 80}, {U'ð', 81}, {U'ʤ', 82}, + {U'ə', 83}, {U'ɚ', 85}, {U'ɛ', 86}, {U'ɜ', 87}, {U'ɟ', 90}, {U'ɡ', 92}, + {U'ɥ', 99}, {U'ɨ', 101}, {U'ɪ', 102}, {U'ʝ', 103}, {U'ɯ', 110}, {U'ɰ', 111}, + {U'ŋ', 112}, {U'ɳ', 113}, {U'ɲ', 114}, {U'ɴ', 115}, {U'ø', 116}, {U'ɸ', 118}, + {U'θ', 119}, {U'œ', 120}, {U'ɹ', 123}, {U'ɾ', 125}, {U'ɻ', 126}, {U'ʁ', 128}, + {U'ɽ', 129}, {U'ʂ', 130}, {U'ʃ', 131}, {U'ʈ', 132}, {U'ʧ', 133}, {U'ʊ', 135}, + {U'ʋ', 136}, {U'ʌ', 138}, {U'ɣ', 139}, {U'ɤ', 140}, {U'χ', 142}, {U'ʎ', 143}, + {U'ʒ', 147}, {U'ʔ', 148}, {U'ˈ', 156}, {U'ˌ', 157}, {U'ː', 158}, {U'ʰ', 162}, + {U'ʲ', 164}, {U'↓', 169}, {U'→', 171}, {U'↗', 172}, {U'↘', 173}, {U'ᵻ', 177}}; // Special tokens inline constexpr Token kInvalidToken = -1; inline constexpr Token kPadToken = 0; -} // namespace rnexecutorch::models::text_to_speech::kokoro::constants \ No newline at end of file +} // namespace rnexecutorch::models::text_to_speech::kokoro::constants diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/DurationPredictor.cpp b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/DurationPredictor.cpp index a3d27574bc..4c96e0ce05 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/DurationPredictor.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/DurationPredictor.cpp @@ -15,15 +15,13 @@ using ::executorch::aten::ScalarType; using ::executorch::extension::make_tensor_ptr; using ::executorch::extension::TensorPtr; -DurationPredictor::DurationPredictor( - const std::string &modelSource, const Context &modelContext, - std::shared_ptr callInvoker) +DurationPredictor::DurationPredictor(const std::string &modelSource, const Context &modelContext, + std::shared_ptr callInvoker) : BaseModel(modelSource, callInvoker), context_(modelContext) { auto availableMethods = module_->method_names(); if (!availableMethods.ok()) { - throw RnExecutorchError( - RnExecutorchErrorCode::UnknownError, - "[Kokoro::DurationPredictor] Unable to read model's methods"); + throw RnExecutorchError(RnExecutorchErrorCode::UnknownError, + "[Kokoro::DurationPredictor] Unable to read model's methods"); } // Recognize available forward methods @@ -45,9 +43,8 @@ DurationPredictor::DurationPredictor( } // Sort the forward methods by input size - std::stable_sort( - forwardMethods_.begin(), forwardMethods_.end(), - [](const auto &a, const auto &b) { return a.second < b.second; }); + std::stable_sort(forwardMethods_.begin(), forwardMethods_.end(), + [](const auto &a, const auto &b) { return a.second < b.second; }); } std::tuple, int32_t, std::vector> @@ -62,35 +59,31 @@ DurationPredictor::generate(std::span tokens, std::span textMask, CHECK_SIZE(ref_hs, constants::kVoiceRefHalfSize); // Select appropriate forward method - auto it = - std::ranges::find_if(forwardMethods_, [inputSize](const auto &entry) { - return entry.second >= inputSize; - }); + auto it = std::ranges::find_if( + forwardMethods_, [inputSize](const auto &entry) { return entry.second >= inputSize; }); if (it == forwardMethods_.end()) { - throw RnExecutorchError( - RnExecutorchErrorCode::WrongDimensions, - "[Kokoro::DurationPredictor] No appropriate forward method to" - "handle input of size " + - std::to_string(inputSize)); + throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, + "[Kokoro::DurationPredictor] No appropriate forward method to" + "handle input of size " + + std::to_string(inputSize)); } auto selectedMethod = it->first; // Convert input data to ExecuTorch tensors - auto tokensTensor = make_tensor_ptr({1, static_cast(tokens.size())}, - tokens.data(), ScalarType::Long); + auto tokensTensor = + make_tensor_ptr({1, static_cast(tokens.size())}, tokens.data(), ScalarType::Long); - auto textMaskTensor = - make_tensor_ptr({1, static_cast(textMask.size())}, - textMask.data(), ScalarType::Bool); + auto textMaskTensor = make_tensor_ptr({1, static_cast(textMask.size())}, textMask.data(), + ScalarType::Bool); - auto voiceRefTensor = make_tensor_ptr({1, constants::kVoiceRefHalfSize}, - ref_hs.data(), ScalarType::Float); + auto voiceRefTensor = + make_tensor_ptr({1, constants::kVoiceRefHalfSize}, ref_hs.data(), ScalarType::Float); auto speedTensor = make_tensor_ptr({1}, &speed, ScalarType::Float); // Execute the appropriate "forward_xyz" method, based on given method name - auto results = execute(selectedMethod, {tokensTensor, textMaskTensor, - voiceRefTensor, speedTensor}); + auto results = + execute(selectedMethod, {tokensTensor, textMaskTensor, voiceRefTensor, speedTensor}); CHECK_OK_OR_THROW_FORWARD_ERROR(results); @@ -101,8 +94,8 @@ DurationPredictor::generate(std::span tokens, std::span textMask, // Scale output durations if it exceedes the limits size_t totalDur = std::reduce(predDurPtr, predDurPtr + inputSize); - size_t clampedDur = std::clamp(totalDur, constants::kMinDurationTicks, - context_.inputDurationLimit); + size_t clampedDur = + std::clamp(totalDur, constants::kMinDurationTicks, context_.inputDurationLimit); if (totalDur != clampedDur) { scaleDurations(predDurTensor, inputSize, clampedDur); } @@ -111,23 +104,18 @@ DurationPredictor::generate(std::span tokens, std::span textMask, std::vector idxs(inputSize); std::iota(idxs.begin(), idxs.end(), 0LL); std::vector indices = rnexecutorch::sequential::repeatInterleave( - std::span(idxs), - std::span(predDurPtr, inputSize)); + std::span(idxs), std::span(predDurPtr, inputSize)); // Calculate the effective duration // Note that we lower effective duration even further, to remove // some of the side-effects at the end of the audio. int32_t originalLength = - std::distance(tokens.begin(), - std::find(tokens.begin() + 1, tokens.end(), 0)) + - 1; + std::distance(tokens.begin(), std::find(tokens.begin() + 1, tokens.end(), 0)) + 1; int32_t effDuration = std::distance( - indices.begin(), - std::lower_bound(indices.begin(), indices.end(), originalLength)); + indices.begin(), std::lower_bound(indices.begin(), indices.end(), originalLength)); // Calculate timestamps - based on predicted durations. - std::vector timestamps = - calculateTimestamps(predDurPtr, inputSize); + std::vector timestamps = calculateTimestamps(predDurPtr, inputSize); /** * Returns: @@ -135,24 +123,22 @@ DurationPredictor::generate(std::span tokens, std::span textMask, * - indices: vector of repeated token indices according to durations. * - effDuration: an effective duration after post-processing. */ - return std::make_tuple(std::move(dTensor), std::move(indices), - std::move(effDuration), std::move(timestamps)); + return std::make_tuple(std::move(dTensor), std::move(indices), std::move(effDuration), + std::move(timestamps)); } size_t DurationPredictor::getTokensLimit() const { return forwardMethods_.empty() ? 0 : forwardMethods_.back().second; } -std::vector -DurationPredictor::calculateTimestamps(const int64_t *predDurPtr, - size_t inputSize) const { +std::vector DurationPredictor::calculateTimestamps(const int64_t *predDurPtr, + size_t inputSize) const { std::vector timestamps; timestamps.reserve(inputSize); size_t accDur = 0; for (size_t i = 0; i < inputSize; i++) { - int64_t dur = predDurPtr[i] * - constants::kTicksPerDuration; // Convert to audio samples + int64_t dur = predDurPtr[i] * constants::kTicksPerDuration; // Convert to audio samples timestamps.emplace_back(accDur, accDur + dur); accDur += dur; } @@ -163,18 +149,15 @@ DurationPredictor::calculateTimestamps(const int64_t *predDurPtr, void DurationPredictor::scaleDurations(Tensor &durations, size_t nTokens, int32_t targetDuration) const { // We expect durations tensor to be a Long tensor of a shape [1, n_tokens] - if (durations.dtype() != ScalarType::Long && - durations.dtype() != ScalarType::Int) { - throw RnExecutorchError( - RnExecutorchErrorCode::InvalidModelOutput, - "[Kokoro::DurationPredictor] Attempted to scale a non-integer tensor"); + if (durations.dtype() != ScalarType::Long && durations.dtype() != ScalarType::Int) { + throw RnExecutorchError(RnExecutorchErrorCode::InvalidModelOutput, + "[Kokoro::DurationPredictor] Attempted to scale a non-integer tensor"); } auto shape = durations.sizes(); if (shape.size() != 1) { - throw RnExecutorchError( - RnExecutorchErrorCode::InvalidModelOutput, - "[Kokoro::DurationPredictor] Attempted to scale an ill-shaped tensor"); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidModelOutput, + "[Kokoro::DurationPredictor] Attempted to scale an ill-shaped tensor"); } int64_t *durationsPtr = durations.mutable_data_ptr(); @@ -186,16 +169,13 @@ void DurationPredictor::scaleDurations(Tensor &durations, size_t nTokens, // We need to scale partial durations (integers) corresponding to each token // in a way that they all sum up to target duration, while keeping the balance // between the values. - std::priority_queue> - remainders; // Sorted by the first value + std::priority_queue> remainders; // Sorted by the first value int64_t scaledSum = 0; for (uint32_t i = 0; i < nTokens; i++) { float scaled = scaleFactor * durationsPtr[i]; - float remainder = - shrinking ? std::ceil(scaled) - scaled : scaled - std::floor(scaled); + float remainder = shrinking ? std::ceil(scaled) - scaled : scaled - std::floor(scaled); - durationsPtr[i] = static_cast(shrinking ? std::ceil(scaled) - : std::floor(scaled)); + durationsPtr[i] = static_cast(shrinking ? std::ceil(scaled) : std::floor(scaled)); scaledSum += durationsPtr[i]; // Keeps the entries sorted by the remainders diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/DurationPredictor.h b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/DurationPredictor.h index 2ace0b9b25..826d95cf10 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/DurationPredictor.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/DurationPredictor.h @@ -18,8 +18,7 @@ using executorch::aten::Tensor; class DurationPredictor : public BaseModel { public: - explicit DurationPredictor(const std::string &modelSource, - const Context &modelContext, + explicit DurationPredictor(const std::string &modelSource, const Context &modelContext, std::shared_ptr callInvoker); /** @@ -39,30 +38,27 @@ class DurationPredictor : public BaseModel { * timestamps - timestamp marks for each token (phoneme) */ std::tuple, int32_t, std::vector> - generate(std::span tokens, std::span textMask, - std::span ref_hs, float speed = 1.F); + generate(std::span tokens, std::span textMask, std::span ref_hs, + float speed = 1.F); // Returns maximum supported amount of input tokens. size_t getTokensLimit() const; private: // Helper function - calculating timestamps based on predicted durations - std::vector calculateTimestamps(const int64_t *predDurPtr, - size_t inputSize) const; + std::vector calculateTimestamps(const int64_t *predDurPtr, size_t inputSize) const; // Helper function - duration scalling // Performs integer scaling on the durations tensor to ensure the sum of // durations matches the given target duration - void scaleDurations( - Tensor &durations, size_t nTokens, - int32_t targetDuration) const; // Helper function - calculating effective - // duration based on duration tensor + void scaleDurations(Tensor &durations, size_t nTokens, + int32_t targetDuration) const; // Helper function - calculating effective + // duration based on duration tensor // Since we apply padding to the input, the effective duration is // usually a little bit lower than the max duration defined by static input // size. - int32_t calculateEffectiveDuration(const Tensor &d, - const std::vector &indices) const; + int32_t calculateEffectiveDuration(const Tensor &d, const std::vector &indices) const; // Available forward methods // In order to speed-up the calculations, we allow DurationPredictor to diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Kokoro.cpp b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Kokoro.cpp index aa8486c276..058846d477 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Kokoro.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Kokoro.cpp @@ -13,29 +13,23 @@ namespace rnexecutorch::models::text_to_speech::kokoro { Kokoro::Kokoro(const std::string &lang, const std::string &taggerDataSource, - const std::string &lexiconSource, - const std::string &neuralModelSource, - const std::string &durationPredictorSource, - const std::string &synthesizerSource, - const std::string &voiceSource, - std::shared_ptr callInvoker) + const std::string &lexiconSource, const std::string &neuralModelSource, + const std::string &durationPredictorSource, const std::string &synthesizerSource, + const std::string &voiceSource, std::shared_ptr callInvoker) : callInvoker_(std::move(callInvoker)), phonemizer_(phonemis::Config{ .lang = lang, - .tagger = taggerDataSource.empty() - ? std::optional{} - : std::make_optional(phonemis::tagger::Config{ - .data_filepath = taggerDataSource}), + .tagger = taggerDataSource.empty() ? std::optional{} + : std::make_optional(phonemis::tagger::Config{ + .data_filepath = taggerDataSource}), .phonemizer = phonemis::phonemizer::Config{ .lang = lang, - .lexicon_filepath = lexiconSource.empty() - ? std::nullopt - : std::make_optional(lexiconSource), - .nn_model_filepath = - neuralModelSource.empty() - ? std::nullopt - : std::make_optional(neuralModelSource)}}), + .lexicon_filepath = + lexiconSource.empty() ? std::nullopt : std::make_optional(lexiconSource), + .nn_model_filepath = neuralModelSource.empty() + ? std::nullopt + : std::make_optional(neuralModelSource)}}), durationPredictor_(durationPredictorSource, context_, callInvoker_), synthesizer_(synthesizerSource, context_, callInvoker_) { // Populate the voice array by reading given file @@ -43,9 +37,8 @@ Kokoro::Kokoro(const std::string &lang, const std::string &taggerDataSource, // Read model limits & check compatibility if (durationPredictor_.getTokensLimit() != synthesizer_.getTokensLimit()) { - throw RnExecutorchError( - RnExecutorchErrorCode::WrongDimensions, - "[Kokoro] incompatible DurationPredictor & Synthesizer models"); + throw RnExecutorchError(RnExecutorchErrorCode::WrongDimensions, + "[Kokoro] incompatible DurationPredictor & Synthesizer models"); } context_.inputTokensLimit = durationPredictor_.getTokensLimit(); @@ -59,8 +52,7 @@ void Kokoro::loadVoice(const std::string &voiceSource) { std::ifstream in(voiceSource, std::ios::binary); if (!in) { throw RnExecutorchError(RnExecutorchErrorCode::FileReadFailed, - "[Kokoro::loadVoice]: cannot open file: " + - voiceSource); + "[Kokoro::loadVoice]: cannot open file: " + voiceSource); } // Determine number of rows from file size @@ -69,11 +61,10 @@ void Kokoro::loadVoice(const std::string &voiceSource) { in.seekg(0, std::ios::beg); if (fileSize < bytesPerRow) { - throw RnExecutorchError( - RnExecutorchErrorCode::FileReadFailed, - "[Kokoro::loadVoice]: file too small: need at least " + - std::to_string(bytesPerRow) + " bytes for one row, got " + - std::to_string(fileSize)); + throw RnExecutorchError(RnExecutorchErrorCode::FileReadFailed, + "[Kokoro::loadVoice]: file too small: need at least " + + std::to_string(bytesPerRow) + " bytes for one row, got " + + std::to_string(fileSize)); } const size_t rows = fileSize / bytesPerRow; @@ -83,14 +74,12 @@ void Kokoro::loadVoice(const std::string &voiceSource) { voice_.resize(rows); if (!in.read(reinterpret_cast(voice_.data()->data()), readBytes)) { - throw RnExecutorchError( - RnExecutorchErrorCode::FileReadFailed, - "[Kokoro::loadVoice]: failed to read voice weights"); + throw RnExecutorchError(RnExecutorchErrorCode::FileReadFailed, + "[Kokoro::loadVoice]: failed to read voice weights"); } } -std::vector Kokoro::generate(std::u32string input, float speed, - bool phonemize) { +std::vector Kokoro::generate(std::u32string input, float speed, bool phonemize) { if (input.size() > params::kMaxTextSize) { throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, "Kokoro: maximum input text size exceeded"); @@ -99,15 +88,13 @@ std::vector Kokoro::generate(std::u32string input, float speed, if (speed < constants::kMinValidSpeed) { throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, "Kokoro: speed value too low (min " + - std::to_string(constants::kMinValidSpeed) + - ")"); + std::to_string(constants::kMinValidSpeed) + ")"); } if (speed > constants::kMaxValidSpeed) { throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, "Kokoro: speed value too high (max " + - std::to_string(constants::kMaxValidSpeed) + - ")"); + std::to_string(constants::kMaxValidSpeed) + ")"); } if (input.empty()) { @@ -119,8 +106,8 @@ std::vector Kokoro::generate(std::u32string input, float speed, // Divide the phonemes string into substrings, minimizing the amount of // breaks. - auto partition = partitioner_.partition(phonemes, context_.inputTokensLimit, - Partitioner::Mode::MIN_BREAKS); + auto partition = + partitioner_.partition(phonemes, context_.inputTokensLimit, Partitioner::Mode::MIN_BREAKS); std::vector audio = {}; for (const auto &[offset, length] : partition.segments) { @@ -138,37 +125,32 @@ std::vector Kokoro::generate(std::u32string input, float speed, // Add audio part and silence pause to the main audio vector audio.insert(audio.end(), std::make_move_iterator(audioPart.begin()), std::make_move_iterator(audioPart.end())); - audio.resize(audio.size() + pauseMs * constants::kSamplesPerMilisecond, - 0.F); + audio.resize(audio.size() + pauseMs * constants::kSamplesPerMilisecond, 0.F); } return audio; } -void Kokoro::stream(std::shared_ptr callback, float speed, - bool phonemize, bool stopOnEmptyBuffer) { +void Kokoro::stream(std::shared_ptr callback, float speed, bool phonemize, + bool stopOnEmptyBuffer) { if (speed < constants::kMinValidSpeed) { throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, "Kokoro: speed value too low (min " + - std::to_string(constants::kMinValidSpeed) + - ")"); + std::to_string(constants::kMinValidSpeed) + ")"); } if (speed > constants::kMaxValidSpeed) { throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, "Kokoro: speed value too high (max " + - std::to_string(constants::kMaxValidSpeed) + - ")"); + std::to_string(constants::kMaxValidSpeed) + ")"); } // Create a callback auto nativeCallback = [this, callback](const std::vector &audioVec) { if (this->isStreaming_) { - this->callInvoker_->invokeAsync( - [callback, audioVec = std::move(audioVec)](jsi::Runtime &rt) { - callback->call( - rt, rnexecutorch::jsi_conversion::getJsiValue(audioVec, rt)); - }); + this->callInvoker_->invokeAsync([callback, audioVec = std::move(audioVec)](jsi::Runtime &rt) { + callback->call(rt, rnexecutorch::jsi_conversion::getJsiValue(audioVec, rt)); + }); } }; @@ -187,8 +169,7 @@ void Kokoro::stream(std::shared_ptr callback, float speed, { // Trim to remove trailing whitespace characters inputTextBuffer_ = - phonemis::utils::strings::strip( - inputTextBuffer_); + phonemis::utils::strings::strip(inputTextBuffer_); std::scoped_lock lock(inputTextBufferMutex_); if (inputTextBuffer_.empty() && stopOnEmptyBuffer_) { @@ -196,15 +177,13 @@ void Kokoro::stream(std::shared_ptr callback, float speed, } // Try to find the most recent available end of sentence character. - size_t searchLimit = - std::min(inputTextBuffer_.size(), params::kMaxTextSize); - auto eosIt = std::find_first_of( - inputTextBuffer_.rbegin() + (inputTextBuffer_.size() - searchLimit), - inputTextBuffer_.rend(), constants::kEndOfSentenceCharacters.begin(), - constants::kEndOfSentenceCharacters.end()); - size_t chunkSize = (eosIt != inputTextBuffer_.rend()) - ? std::distance(eosIt, inputTextBuffer_.rend()) - : 0; + size_t searchLimit = std::min(inputTextBuffer_.size(), params::kMaxTextSize); + auto eosIt = + std::find_first_of(inputTextBuffer_.rbegin() + (inputTextBuffer_.size() - searchLimit), + inputTextBuffer_.rend(), constants::kEndOfSentenceCharacters.begin(), + constants::kEndOfSentenceCharacters.end()); + size_t chunkSize = + (eosIt != inputTextBuffer_.rend()) ? std::distance(eosIt, inputTextBuffer_.rend()) : 0; // Default behavior: hold back partial content until an EOS arrives, so // we don't synthesize mid-sentence (relevant for LLM token streaming). @@ -237,8 +216,8 @@ void Kokoro::stream(std::shared_ptr callback, float speed, // Since we do not phonemize the entire input before partitioning, there // is a possibility that some segment might exceed the token limit after // phonemization. This is being handled later. - auto partition = partitioner_.partition( - buffer, context_.inputTokensLimit, Partitioner::Mode::MIN_LATENCY); + auto partition = partitioner_.partition(buffer, context_.inputTokensLimit, + Partitioner::Mode::MIN_LATENCY); for (size_t i = 0; i < partition.segments.size(); i++) { if (!isStreaming_) { @@ -253,8 +232,7 @@ void Kokoro::stream(std::shared_ptr callback, float speed, if (phonemize) { size_t unchangedLength = std::min(length, phonemizedTokens); // Include trailing space if it was already phonemized - if (unchangedLength < length && - subsentence[unchangedLength] == U' ' && + if (unchangedLength < length && subsentence[unchangedLength] == U' ' && phonemizedTokens > unchangedLength) { unchangedLength++; } @@ -265,8 +243,7 @@ void Kokoro::stream(std::shared_ptr callback, float speed, phonemes = subsentence.substr(0, unchangedLength); if (unchangedLength < length) { // Phonemize without preprocessing (since we already did that). - phonemes += - phonemizer_(subsentence.substr(unchangedLength), false); + phonemes += phonemizer_(subsentence.substr(unchangedLength), false); } } else { // Simple case - no phonemization, no risk of exceeding the token @@ -277,8 +254,7 @@ void Kokoro::stream(std::shared_ptr callback, float speed, if (phonemes.size() <= context_.inputTokensLimit - 2) { // Determine the silent padding duration bool endsWithSpace = (subsentence.back() == U' '); - bool prevEndsWithSpace = - (offset > 0 && partition.content[offset - 1] == U' '); + bool prevEndsWithSpace = (offset > 0 && partition.content[offset - 1] == U' '); size_t paddingMs = endsWithSpace || prevEndsWithSpace ? 15 : 50; // Generate and push audio @@ -288,9 +264,7 @@ void Kokoro::stream(std::shared_ptr callback, float speed, ? params::kPauseValues.at(phonemes.back()) : params::kDefaultPause; - audioPart.resize(audioPart.size() + - pauseMs * constants::kSamplesPerMilisecond, - 0.F); + audioPart.resize(audioPart.size() + pauseMs * constants::kSamplesPerMilisecond, 0.F); nativeCallback(std::move(audioPart)); @@ -315,8 +289,7 @@ void Kokoro::stream(std::shared_ptr callback, float speed, // A little bit of pause to not overload the thread. if (isStreaming_) { - std::this_thread::sleep_for( - std::chrono::milliseconds(params::kStreamPause)); + std::this_thread::sleep_for(std::chrono::milliseconds(params::kStreamPause)); } } @@ -328,8 +301,7 @@ void Kokoro::stream(std::shared_ptr callback, float speed, } } -std::vector Kokoro::synthesize(std::u32string_view phonemes, float speed, - size_t paddingMs) { +std::vector Kokoro::synthesize(std::u32string_view phonemes, float speed, size_t paddingMs) { if (phonemes.empty()) { return {}; } @@ -343,41 +315,33 @@ std::vector Kokoro::synthesize(std::u32string_view phonemes, float speed, // Clamp input to avoid exceeding model limits (2 tokens reserved for pre/post // padding). const size_t noTokens = - std::clamp(phonemes.size() + 2, constants::kMinInputTokens, - context_.inputTokensLimit); + std::clamp(phonemes.size() + 2, constants::kMinInputTokens, context_.inputTokensLimit); auto tokens = utils::tokenize(phonemes, {noTokens}); // 2. Initialize text mask. // Exclude all paddings except the first and last ones. // We use uint8_t instead of bool to avoid boolean span issues. std::vector textMask(noTokens, false); - std::fill(textMask.begin(), - textMask.begin() + std::min(phonemes.size() + 2, noTokens), true); + std::fill(textMask.begin(), textMask.begin() + std::min(phonemes.size() + 2, noTokens), true); // 3. Select the appropriate voice vector. // Each number of input tokens corresponds to a different voice embedding // vector. - const size_t voiceID = - std::min({phonemes.size() - 1, noTokens - 1, voice_.size() - 1}); + const size_t voiceID = std::min({phonemes.size() - 1, noTokens - 1, voice_.size() - 1}); auto &voice = voice_[voiceID]; // 4. Inference Phase 1: DurationPredictor (submodule). - auto [d, indices, effectiveDuration, timestamps] = - durationPredictor_.generate( - std::span(tokens), - std::span(reinterpret_cast(textMask.data()), textMask.size()), - std::span(voice).last(constants::kVoiceRefHalfSize), speed); + auto [d, indices, effectiveDuration, timestamps] = durationPredictor_.generate( + std::span(tokens), std::span(reinterpret_cast(textMask.data()), textMask.size()), + std::span(voice).last(constants::kVoiceRefHalfSize), speed); // 5. Inference Phase 2: Synthesizer. // Note that we reduce the size of the duration tensor to match the number of // tokens. auto decoding = synthesizer_.generate( - std::span(tokens), - std::span(reinterpret_cast(textMask.data()), textMask.size()), + std::span(tokens), std::span(reinterpret_cast(textMask.data()), textMask.size()), std::span(indices), - std::span(d.mutable_data_ptr(), - noTokens * d.sizes().back()), - std::span(voice)); + std::span(d.mutable_data_ptr(), noTokens * d.sizes().back()), std::span(voice)); // 6. Post-processing: Finalize audio. auto audioTensor = decoding->at(0).toTensor(); @@ -388,8 +352,7 @@ std::vector Kokoro::synthesize(std::u32string_view phonemes, float speed, const int32_t audioLength = constants::kTicksPerDuration * effectiveDuration; - auto audio = - std::span(audioTensor.const_data_ptr(), audioLength); + auto audio = std::span(audioTensor.const_data_ptr(), audioLength); // To counter any potential trailing voice artifacts (which can occur due to // slight mismatch of .pte model results) we cut it according to the predicted @@ -397,17 +360,15 @@ std::vector Kokoro::synthesize(std::u32string_view phonemes, float speed, if (noTokens > 2) { // We want to skip both the last PAD token, as well as any potential EOS // token just before it. - auto lastTokenTimestamp = - !phonemis::utils::unicode::isalpha(phonemes.back()) - ? timestamps[noTokens - 3].end - : timestamps[noTokens - 2].end; + auto lastTokenTimestamp = !phonemis::utils::unicode::isalpha(phonemes.back()) + ? timestamps[noTokens - 3].end + : timestamps[noTokens - 2].end; audio = audio.subspan(0, std::min(lastTokenTimestamp, audio.size())); } // Now additional stripping of a (hopefully) pure silence. - audio = - utils::stripAudio(audio, paddingMs * constants::kSamplesPerMilisecond); + audio = utils::stripAudio(audio, paddingMs * constants::kSamplesPerMilisecond); return {audio.begin(), audio.end()}; } @@ -430,9 +391,8 @@ void Kokoro::streamStop(bool instant) noexcept { } std::size_t Kokoro::getMemoryLowerBound() const noexcept { - return durationPredictor_.getMemoryLowerBound() + - synthesizer_.getMemoryLowerBound() + sizeof(voice_) + - sizeof(phonemizer_); + return durationPredictor_.getMemoryLowerBound() + synthesizer_.getMemoryLowerBound() + + sizeof(voice_) + sizeof(phonemizer_); } void Kokoro::unload() noexcept { diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Kokoro.h b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Kokoro.h index ef9fa432b6..11922b67ba 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Kokoro.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Kokoro.h @@ -21,9 +21,8 @@ class Kokoro { public: Kokoro(const std::string &lang, const std::string &taggerDataSource, const std::string &lexiconSource, const std::string &neuralModelSource, - const std::string &durationPredictorSource, - const std::string &synthesizerSource, const std::string &voiceSource, - std::shared_ptr callInvoker); + const std::string &durationPredictorSource, const std::string &synthesizerSource, + const std::string &voiceSource, std::shared_ptr callInvoker); /** * Generates complete audio for the provided text. @@ -35,8 +34,7 @@ class Kokoro { * operates on raw input. * @return A vector of PCM float samples representing the synthesized speech. */ - std::vector generate(std::u32string input, float speed = 1.F, - bool phonemize = true); + std::vector generate(std::u32string input, float speed = 1.F, bool phonemize = true); /** * Starts an asynchronous streaming process that processes text in chunks. @@ -50,8 +48,8 @@ class Kokoro { * @param stopOnEmptyBuffer If true, streaming terminates automatically when * the buffer is exhausted. */ - void stream(std::shared_ptr callback, float speed = 1.F, - bool phonemize = true, bool stopOnEmptyBuffer = false); + void stream(std::shared_ptr callback, float speed = 1.F, bool phonemize = true, + bool stopOnEmptyBuffer = false); /** * Appends new input data (either text or phonemes) to the buffer. @@ -87,8 +85,7 @@ class Kokoro { private: // --- Initialization & Core Inference --- void loadVoice(const std::string &voiceSource); - std::vector synthesize(std::u32string_view phonemes, float speed, - size_t paddingMs = 50); + std::vector synthesize(std::u32string_view phonemes, float speed, size_t paddingMs = 50); // --- External Dependencies --- std::shared_ptr callInvoker_; @@ -121,9 +118,8 @@ class Kokoro { }; } // namespace models::text_to_speech::kokoro -REGISTER_CONSTRUCTOR(models::text_to_speech::kokoro::Kokoro, std::string, +REGISTER_CONSTRUCTOR(models::text_to_speech::kokoro::Kokoro, std::string, std::string, std::string, std::string, std::string, std::string, std::string, - std::string, std::string, std::shared_ptr); } // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Params.h b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Params.h index 5f61287f02..456522ec29 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Params.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Params.h @@ -33,9 +33,16 @@ inline constexpr int32_t kStreamPause = 200; * (ms). */ inline const std::unordered_map kPauseValues = { - {U'.', 375}, {U'?', 500}, {U'!', 250}, {U';', 400}, {U'…', 600}, // Ellipsis - {U',', 130}, {U':', 250}, {U'-', 200}, {U'—', 250}, // Em Dash (slightly - // longer than hyphen) + {U'.', 375}, + {U'?', 500}, + {U'!', 250}, + {U';', 400}, + {U'…', 600}, // Ellipsis + {U',', 130}, + {U':', 250}, + {U'-', 200}, + {U'—', 250}, // Em Dash (slightly + // longer than hyphen) {U'|', 375}, // ASCII Pipe (treated as full stop) {U'।', 375}, // Hindi Purna Viram {U'॥', 500}, // Hindi Deergh Viram (typically longer than Purna Viram) @@ -99,4 +106,4 @@ inline constexpr uint64_t kWhiteMinLatencyCost = 1000; } // namespace partitioning -} // namespace rnexecutorch::models::text_to_speech::kokoro::params \ No newline at end of file +} // namespace rnexecutorch::models::text_to_speech::kokoro::params diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Partitioner.cpp b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Partitioner.cpp index d994d98d74..d9075c965d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Partitioner.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Partitioner.cpp @@ -14,11 +14,10 @@ using namespace params::partitioning; // Custom infinity definition constexpr Partitioner::Cost INF = 1e7; -Partitioner::Partition Partitioner::partition(std::u32string_view input, - size_t limit, Mode mode) const { +Partitioner::Partition Partitioner::partition(std::u32string_view input, size_t limit, + Mode mode) const { if (mode == Mode::MIN_BREAKS) { - auto minBreakCostFn = [limit](Cost acc, size_t beg, int64_t prevBp, - int64_t bp, size_t end, + auto minBreakCostFn = [limit](Cost acc, size_t beg, int64_t prevBp, int64_t bp, size_t end, Separator sep) -> Cost { if (end - bp > limit) { return INF; @@ -36,8 +35,7 @@ Partitioner::Partition Partitioner::partition(std::u32string_view input, } if (mode == Mode::MIN_LATENCY) { - auto minLatencyCostFn = [limit](Cost acc, size_t beg, int64_t prevBp, - int64_t bp, size_t end, + auto minLatencyCostFn = [limit](Cost acc, size_t beg, int64_t prevBp, int64_t bp, size_t end, Separator sep) -> Cost { if (end - bp > limit) { return INF; @@ -51,14 +49,11 @@ Partitioner::Partition Partitioner::partition(std::u32string_view input, int64_t rightmostRangeLength = end - bp; int64_t prevRangeLength = bp - prevBp; - int64_t latency = std::max(static_cast(0), - rightmostRangeLength - prevRangeLength); + int64_t latency = std::max(static_cast(0), rightmostRangeLength - prevRangeLength); int64_t discount = - kTokenDiscountFactor * - std::max(static_cast(0), kTokenDiscountRange - bp - 1); + kTokenDiscountFactor * std::max(static_cast(0), kTokenDiscountRange - bp - 1); - return acc + static_cast(latency * discount / kTokenDiscountRange) + - sepPenalty; + return acc + static_cast(latency * discount / kTokenDiscountRange) + sepPenalty; }; return partition(input, limit, minLatencyCostFn); @@ -67,8 +62,7 @@ Partitioner::Partition Partitioner::partition(std::u32string_view input, return {input, {}}; } -Partitioner::Partition Partitioner::partition(std::u32string_view input, - size_t limit, +Partitioner::Partition Partitioner::partition(std::u32string_view input, size_t limit, CostFn costFn) const { if (input.empty()) { return {input, {}}; @@ -93,8 +87,7 @@ Partitioner::Partition Partitioner::partition(std::u32string_view input, : q == &pausePoints ? Separator::PAUSE : Separator::WHITE; for (size_t breakIdx : (*q)) { - auto cost = costFn(dp[breakIdx].first, 0, dp[breakIdx].second, breakIdx, - i, sep); + auto cost = costFn(dp[breakIdx].first, 0, dp[breakIdx].second, breakIdx, i, sep); if (cost < bestCost && breakIdx > 0) { bestCost = cost; prevBpIdx = breakIdx; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Partitioner.h b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Partitioner.h index 93f2b97c84..cf12054064 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Partitioner.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Partitioner.h @@ -18,9 +18,8 @@ class Partitioner { * divided. */ enum class Mode { - MIN_BREAKS = 0, // Minimizes number of substrings (best quality) - MIN_LATENCY = - 1, // Minimizes the processing latency (best speed - streaming mode) + MIN_BREAKS = 0, // Minimizes number of substrings (best quality) + MIN_LATENCY = 1, // Minimizes the processing latency (best speed - streaming mode) }; /** @@ -52,8 +51,8 @@ class Partitioner { * @param end End index of the current range (inclusive). * @param sep The type of the breakpoint. */ - using CostFn = std::function; + using CostFn = std::function; /** * Holds the result of text partitioning. @@ -62,8 +61,7 @@ class Partitioner { */ struct Partition { std::u32string_view content; - std::vector> - segments; // Pairs of {offset, length} for each segment. + std::vector> segments; // Pairs of {offset, length} for each segment. }; /** @@ -76,13 +74,11 @@ class Partitioner { * @return A Partition object containing the original content view and * breakpoints. */ - Partition partition(std::u32string_view input, size_t limit, - Mode mode = Mode::MIN_LATENCY) const; + Partition partition(std::u32string_view input, size_t limit, Mode mode = Mode::MIN_LATENCY) const; private: // Internal partition implementation that uses a specific cost function. - Partition partition(std::u32string_view input, size_t limit, - CostFn costFn) const; + Partition partition(std::u32string_view input, size_t limit, CostFn costFn) const; }; -} // namespace rnexecutorch::models::text_to_speech::kokoro \ No newline at end of file +} // namespace rnexecutorch::models::text_to_speech::kokoro diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Synthesizer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Synthesizer.cpp index 9029759caf..ea513408a8 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Synthesizer.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Synthesizer.cpp @@ -13,8 +13,7 @@ using ::executorch::aten::ScalarType; using ::executorch::extension::make_tensor_ptr; using ::executorch::extension::TensorPtr; -Synthesizer::Synthesizer(const std::string &modelSource, - const Context &modelContext, +Synthesizer::Synthesizer(const std::string &modelSource, const Context &modelContext, std::shared_ptr callInvoker) : BaseModel(modelSource, callInvoker), context_(modelContext) { // Discover all forward methods (forward, forward_8, forward_32, etc.) @@ -32,9 +31,8 @@ Synthesizer::Synthesizer(const std::string &modelSource, forwardMethods_.emplace_back(name, inputSize); } } - std::ranges::stable_sort(forwardMethods_, [](const auto &a, const auto &b) { - return a.second < b.second; - }); + std::ranges::stable_sort(forwardMethods_, + [](const auto &a, const auto &b) { return a.second < b.second; }); } // Fallback: if no methods discovered, validate "forward" directly @@ -48,10 +46,8 @@ Synthesizer::Synthesizer(const std::string &modelSource, } } -Result> Synthesizer::generate(std::span tokens, - std::span textMask, - std::span indices, - std::span dur, +Result> Synthesizer::generate(std::span tokens, std::span textMask, + std::span indices, std::span dur, std::span ref_s) { // Perform input shape checks // Both F0 and N vectors should be twice as long as duration @@ -62,37 +58,31 @@ Result> Synthesizer::generate(std::span tokens, int32_t duration = indices.size(); // Convert input data to ExecuTorch tensors - auto tokensTensor = make_tensor_ptr({1, static_cast(tokens.size())}, - tokens.data(), ScalarType::Long); - auto textMaskTensor = - make_tensor_ptr({1, static_cast(textMask.size())}, - textMask.data(), ScalarType::Bool); - auto indicesTensor = - make_tensor_ptr({duration}, indices.data(), ScalarType::Long); - auto durTensor = - make_tensor_ptr({1, noTokens, 640}, dur.data(), ScalarType::Float); - auto voiceRefTensor = make_tensor_ptr({1, constants::kVoiceRefSize}, - ref_s.data(), ScalarType::Float); + auto tokensTensor = + make_tensor_ptr({1, static_cast(tokens.size())}, tokens.data(), ScalarType::Long); + auto textMaskTensor = make_tensor_ptr({1, static_cast(textMask.size())}, textMask.data(), + ScalarType::Bool); + auto indicesTensor = make_tensor_ptr({duration}, indices.data(), ScalarType::Long); + auto durTensor = make_tensor_ptr({1, noTokens, 640}, dur.data(), ScalarType::Float); + auto voiceRefTensor = + make_tensor_ptr({1, constants::kVoiceRefSize}, ref_s.data(), ScalarType::Float); // Select appropriate forward method based on token count - auto it = - std::ranges::find_if(forwardMethods_, [noTokens](const auto &entry) { - return std::cmp_greater_equal(entry.second, noTokens); - }); + auto it = std::ranges::find_if(forwardMethods_, [noTokens](const auto &entry) { + return std::cmp_greater_equal(entry.second, noTokens); + }); std::string selectedMethod = (it != forwardMethods_.end()) ? it->first : forwardMethods_.back().first; // Execute the selected forward method - auto results = - execute(selectedMethod, {tokensTensor, textMaskTensor, indicesTensor, - durTensor, voiceRefTensor}); + auto results = execute(selectedMethod, + {tokensTensor, textMaskTensor, indicesTensor, durTensor, voiceRefTensor}); if (!results.ok()) { throw RnExecutorchError( RnExecutorchErrorCode::InvalidModelOutput, "[Kokoro::Synthesizer] Failed to execute method " + selectedMethod + - ", error: " + - std::to_string(static_cast(results.error()))); + ", error: " + std::to_string(static_cast(results.error()))); } // Returns a single [audio] vector, which contains the @@ -110,4 +100,4 @@ size_t Synthesizer::getDurationLimit() const { return getInputShape(forwardMethods_.back().first, 2)[0]; } -} // namespace rnexecutorch::models::text_to_speech::kokoro \ No newline at end of file +} // namespace rnexecutorch::models::text_to_speech::kokoro diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Synthesizer.h b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Synthesizer.h index c3a21957db..2c99ee086d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Synthesizer.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Synthesizer.h @@ -21,8 +21,7 @@ namespace rnexecutorch::models::text_to_speech::kokoro { */ class Synthesizer : public BaseModel { public: - explicit Synthesizer(const std::string &modelSource, - const Context &modelContext, + explicit Synthesizer(const std::string &modelSource, const Context &modelContext, std::shared_ptr callInvoker); /** @@ -39,10 +38,8 @@ class Synthesizer : public BaseModel { * @param dur duration values, obtained from DurationPredictor module * @param ref_s a full voice array for given duration */ - Result> generate(std::span tokens, - std::span textMask, - std::span indices, - std::span dur, + Result> generate(std::span tokens, std::span textMask, + std::span indices, std::span dur, std::span ref_s); // Model limits getters diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Utils.cpp b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Utils.cpp index f1bf7f8d4d..a893314dab 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Utils.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Utils.cpp @@ -12,9 +12,7 @@ using namespace params::cropping; namespace { -float normalize(float sample) { - return std::max(0.0F, std::abs(sample) - kAudioSilenceThreshold); -} +float normalize(float sample) { return std::max(0.0F, std::abs(sample) - kAudioSilenceThreshold); } template size_t findAudioBound(std::span audio) { if (audio.empty()) { @@ -32,8 +30,8 @@ template size_t findAudioBound(std::span audio) { // Maintain the sliding window sum if (processedCount > kAudioCroppingSteps) { - const size_t oldIndex = reverse ? currentIndex + kAudioCroppingSteps - : currentIndex - kAudioCroppingSteps; + const size_t oldIndex = + reverse ? currentIndex + kAudioCroppingSteps : currentIndex - kAudioCroppingSteps; windowSum -= normalize(audio[oldIndex]); } @@ -67,8 +65,7 @@ std::span stripAudio(std::span audio, size_t margin) { return audio.subspan(lbound, strippedLength); } -std::vector tokenize(std::u32string_view phonemes, - std::optional expectedSize) { +std::vector tokenize(std::u32string_view phonemes, std::optional expectedSize) { if (expectedSize.has_value() && expectedSize.value() < 2) { throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, "[Kokoro::Utils] Expected tokens must be >= 2"); @@ -84,22 +81,20 @@ std::vector tokenize(std::u32string_view phonemes, // 3. Map phonemes to vocabulary tokens // Starting from index 1 to leave index 0 as start-padding - std::transform(phonemes.begin(), phonemes.begin() + effectivePhonemeCount, - tokens.begin() + 1, [](char32_t p) -> Token { - return constants::kVocab.contains(p) - ? constants::kVocab.at(p) - : constants::kInvalidToken; + std::transform(phonemes.begin(), phonemes.begin() + effectivePhonemeCount, tokens.begin() + 1, + [](char32_t p) -> Token { + return constants::kVocab.contains(p) ? constants::kVocab.at(p) + : constants::kInvalidToken; }); // 4. Remove invalid tokens while preserving order (bubbling them to the end // of the content segment) - auto validEnd = std::stable_partition( - tokens.begin() + 1, tokens.begin() + effectivePhonemeCount + 1, - [](Token t) { return t != constants::kInvalidToken; }); + auto validEnd = + std::stable_partition(tokens.begin() + 1, tokens.begin() + effectivePhonemeCount + 1, + [](Token t) { return t != constants::kInvalidToken; }); // 5. Fill any gaps created by partitioning or sizing with pad tokens - std::fill(validEnd, tokens.begin() + effectivePhonemeCount + 1, - constants::kPadToken); + std::fill(validEnd, tokens.begin() + effectivePhonemeCount + 1, constants::kPadToken); return tokens; } diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Utils.h b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Utils.h index c6996a3f40..69d6202d91 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Utils.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_speech/kokoro/Utils.h @@ -13,8 +13,7 @@ namespace rnexecutorch::models::text_to_speech::kokoro::utils { * @param audio The input audio samples. * @param margin Number of silence samples to preserve at each edge. */ -std::span stripAudio(std::span audio, - size_t margin = 0); +std::span stripAudio(std::span audio, size_t margin = 0); /** * Maps phonemes to vocabulary tokens with start/end padding. @@ -24,4 +23,4 @@ std::span stripAudio(std::span audio, std::vector tokenize(std::u32string_view phonemes, std::optional expectedSize = std::nullopt); -} // namespace rnexecutorch::models::text_to_speech::kokoro::utils \ No newline at end of file +} // namespace rnexecutorch::models::text_to_speech::kokoro::utils diff --git a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalDetector.cpp b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalDetector.cpp index 10fe8bfc35..a8ec5a6c4a 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalDetector.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalDetector.cpp @@ -9,42 +9,36 @@ #include namespace rnexecutorch::models::ocr { -VerticalDetector::VerticalDetector( - const std::string &modelSource, - std::shared_ptr callInvoker) +VerticalDetector::VerticalDetector(const std::string &modelSource, + std::shared_ptr callInvoker) : Detector(modelSource, callInvoker) {}; -std::vector -VerticalDetector::generate(const cv::Mat &inputImage, int32_t inputWidth) { +std::vector VerticalDetector::generate(const cv::Mat &inputImage, + int32_t inputWidth) { - bool detectSingleCharacters = - !(inputWidth >= constants::kMediumDetectorWidth); + bool detectSingleCharacters = !(inputWidth >= constants::kMediumDetectorWidth); - utils::validateInputWidth(inputWidth, constants::kDetectorInputWidths, - "VerticalDetector"); + utils::validateInputWidth(inputWidth, constants::kDetectorInputWidths, "VerticalDetector"); std::string methodName = "forward_" + std::to_string(inputWidth); auto inputShapes = getAllInputShapes(methodName); cv::Size modelInputSize = calculateModelImageSize(inputWidth); - cv::Mat resizedInputImage = - image_processing::resizePadded(inputImage, modelInputSize); - TensorPtr inputTensor = image_processing::getTensorFromMatrix( - inputShapes[0], resizedInputImage, constants::kNormalizationMean, - constants::kNormalizationVariance); + cv::Mat resizedInputImage = image_processing::resizePadded(inputImage, modelInputSize); + TensorPtr inputTensor = image_processing::getTensorFromMatrix(inputShapes[0], resizedInputImage, + constants::kNormalizationMean, + constants::kNormalizationVariance); auto forwardResult = BaseModel::execute(methodName, {inputTensor}); CHECK_OK_OR_THROW_FORWARD_ERROR(forwardResult); - return postprocess(forwardResult->at(0).toTensor(), - calculateModelImageSize(inputWidth), + return postprocess(forwardResult->at(0).toTensor(), calculateModelImageSize(inputWidth), detectSingleCharacters); } -std::vector -VerticalDetector::postprocess(const Tensor &tensor, - const cv::Size &modelInputSize, - bool detectSingleCharacters) const { +std::vector VerticalDetector::postprocess(const Tensor &tensor, + const cv::Size &modelInputSize, + bool detectSingleCharacters) const { /* The output of the model consists of two matrices (heat maps): 1. ScoreText(Score map) - The probability of a region containing character. @@ -54,32 +48,28 @@ VerticalDetector::postprocess(const Tensor &tensor, The result of this step is a list of bounding boxes that contain text. */ - std::span tensorData(tensor.const_data_ptr(), - tensor.numel()); + std::span tensorData(tensor.const_data_ptr(), tensor.numel()); /* The output of the model is a matrix half the size of the input image containing two channels representing the heatmaps. */ auto [scoreTextMat, scoreAffinityMat] = utils::interleavedArrayToMats( - tensorData, - cv::Size(modelInputSize.width / 2, modelInputSize.height / 2)); - float txtThreshold = detectSingleCharacters - ? constants::kTextThreshold - : constants::kTextThresholdVertical; + tensorData, cv::Size(modelInputSize.width / 2, modelInputSize.height / 2)); + float txtThreshold = + detectSingleCharacters ? constants::kTextThreshold : constants::kTextThresholdVertical; std::vector bBoxesList = - utils::getDetBoxesFromTextMapVertical( - scoreTextMat, scoreAffinityMat, txtThreshold, - constants::kLinkThreshold, detectSingleCharacters); - const float restoreRatio = utils::calculateRestoreRatio( - scoreTextMat.rows, constants::kRecognizerImageSize); + utils::getDetBoxesFromTextMapVertical(scoreTextMat, scoreAffinityMat, txtThreshold, + constants::kLinkThreshold, detectSingleCharacters); + const float restoreRatio = + utils::calculateRestoreRatio(scoreTextMat.rows, constants::kRecognizerImageSize); utils::restoreBboxRatio(bBoxesList, restoreRatio); // if this is Narrow Detector, do not group boxes. if (!detectSingleCharacters) { - bBoxesList = utils::groupTextBoxes( - bBoxesList, constants::kCenterThreshold, constants::kDistanceThreshold, - constants::kHeightThreshold, constants::kMinSideThreshold, - constants::kMaxSideThreshold, constants::kMaxWidth); + bBoxesList = utils::groupTextBoxes(bBoxesList, constants::kCenterThreshold, + constants::kDistanceThreshold, constants::kHeightThreshold, + constants::kMinSideThreshold, constants::kMaxSideThreshold, + constants::kMaxWidth); } return bBoxesList; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalDetector.h b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalDetector.h index 34b79ce33a..61838c67ad 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalDetector.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalDetector.h @@ -39,12 +39,10 @@ class VerticalDetector final : public Detector { public: explicit VerticalDetector(const std::string &modelSource, std::shared_ptr callInvoker); - std::vector generate(const cv::Mat &inputImage, - int32_t inputWidth) override; + std::vector generate(const cv::Mat &inputImage, int32_t inputWidth) override; private: - std::vector - postprocess(const Tensor &tensor, const cv::Size &modelInputSize, - bool detectSingleCharacters) const; + std::vector postprocess(const Tensor &tensor, const cv::Size &modelInputSize, + bool detectSingleCharacters) const; }; } // namespace rnexecutorch::models::ocr diff --git a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp index c0a531ecd3..c20e5b10d3 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp @@ -11,13 +11,11 @@ #include namespace rnexecutorch::models::ocr { -VerticalOCR::VerticalOCR(const std::string &detectorSource, - const std::string &recognizerSource, +VerticalOCR::VerticalOCR(const std::string &detectorSource, const std::string &recognizerSource, std::string symbols, bool independentChars, std::shared_ptr invoker) - : detector(detectorSource, invoker), recognizer(recognizerSource, invoker), - converter(symbols), independentCharacters(independentChars), - callInvoker(invoker) {} + : detector(detectorSource, invoker), recognizer(recognizerSource, invoker), converter(symbols), + independentCharacters(independentChars), callInvoker(invoker) {} std::vector VerticalOCR::runInference(cv::Mat image) { std::scoped_lock lock(inference_mutex_); @@ -26,10 +24,8 @@ std::vector VerticalOCR::runInference(cv::Mat image) { std::vector largeBoxes = detector.generate(image, constants::kLargeDetectorWidth); - cv::Size largeDetectorSize = - detector.calculateModelImageSize(constants::kLargeDetectorWidth); - cv::Mat resizedImage = - image_processing::resizePadded(image, largeDetectorSize); + cv::Size largeDetectorSize = detector.calculateModelImageSize(constants::kLargeDetectorWidth); + cv::Mat resizedImage = image_processing::resizePadded(image, largeDetectorSize); types::PaddingInfo imagePaddings = utils::calculateResizeRatioAndPaddings(image.size(), largeDetectorSize); @@ -37,15 +33,13 @@ std::vector VerticalOCR::runInference(cv::Mat image) { predictions.reserve(largeBoxes.size()); for (auto &box : largeBoxes) { - predictions.push_back( - _processSingleTextBox(box, image, resizedImage, imagePaddings)); + predictions.push_back(_processSingleTextBox(box, image, resizedImage, imagePaddings)); } return predictions; } -std::vector -VerticalOCR::generateFromString(std::string input) { +std::vector VerticalOCR::generateFromString(std::string input) { cv::Mat image = image_processing::readImage(input); if (image.empty()) { throw RnExecutorchError(RnExecutorchErrorCode::FileReadFailed, @@ -54,9 +48,8 @@ VerticalOCR::generateFromString(std::string input) { return runInference(image); } -std::vector -VerticalOCR::generateFromFrame(jsi::Runtime &runtime, - const jsi::Value &frameData) { +std::vector VerticalOCR::generateFromFrame(jsi::Runtime &runtime, + const jsi::Value &frameData) { auto orient = ::rnexecutorch::utils::readFrameOrientation(runtime, frameData); cv::Mat frame = ::rnexecutorch::utils::frameToMat(runtime, frameData); cv::Mat bgr; @@ -65,28 +58,23 @@ VerticalOCR::generateFromFrame(jsi::Runtime &runtime, #elif defined(__ANDROID__) cv::cvtColor(frame, bgr, cv::COLOR_RGBA2BGR); #else - throw RnExecutorchError( - RnExecutorchErrorCode::PlatformNotSupported, - "generateFromFrame is not supported on this platform"); + throw RnExecutorchError(RnExecutorchErrorCode::PlatformNotSupported, + "generateFromFrame is not supported on this platform"); #endif cv::Mat rotated = ::rnexecutorch::utils::rotateFrameForModel(bgr, orient); auto detections = runInference(rotated); for (auto &det : detections) { std::array corners = {det.bbox.p1, det.bbox.p2}; ::rnexecutorch::utils::inverseRotatePoints(corners, orient, rotated.size()); - det.bbox = {{std::min(corners[0].x, corners[1].x), - std::min(corners[0].y, corners[1].y)}, - {std::max(corners[0].x, corners[1].x), - std::max(corners[0].y, corners[1].y)}}; + det.bbox = {{std::min(corners[0].x, corners[1].x), std::min(corners[0].y, corners[1].y)}, + {std::max(corners[0].x, corners[1].x), std::max(corners[0].y, corners[1].y)}}; } return detections; } -std::vector -VerticalOCR::generateFromPixels(JSTensorViewIn pixelData) { +std::vector VerticalOCR::generateFromPixels(JSTensorViewIn pixelData) { cv::Mat image; - cv::cvtColor(::rnexecutorch::utils::pixelsToMat(pixelData), image, - cv::COLOR_RGB2BGR); + cv::cvtColor(::rnexecutorch::utils::pixelsToMat(pixelData), image, cv::COLOR_RGB2BGR); return runInference(image); } @@ -97,8 +85,7 @@ std::size_t VerticalOCR::getMemoryLowerBound() const noexcept { // Strategy 1: Recognize each character individually std::pair VerticalOCR::_handleIndependentCharacters( const types::DetectorBBox &box, const cv::Mat &originalImage, - const std::vector &characterBoxes, - const types::PaddingInfo &paddingsBox, + const std::vector &characterBoxes, const types::PaddingInfo &paddingsBox, const types::PaddingInfo &imagePaddings) { std::string text; float confidenceScore = 0.0f; @@ -112,19 +99,18 @@ std::pair VerticalOCR::_handleIndependentCharacters( 3. Resize it to [VerticalSmallRecognizerWidth x RecognizerHeight] (64 x 64), */ - auto croppedChar = utils::prepareForRecognition( - originalImage, characterBox.bbox, box.bbox, paddingsBox, imagePaddings); + auto croppedChar = utils::prepareForRecognition(originalImage, characterBox.bbox, box.bbox, + paddingsBox, imagePaddings); /* To make Recognition simpler, we convert cropped character image to a bit mask with white character and black background. */ croppedChar = utils::characterBitMask(croppedChar); - croppedChar = utils::normalizeForRecognizer( - croppedChar, constants::kRecognizerHeight, 0.0, true); + croppedChar = + utils::normalizeForRecognizer(croppedChar, constants::kRecognizerHeight, 0.0, true); - const auto &[predIndex, score] = - recognizer.generate(croppedChar, constants::kRecognizerHeight); + const auto &[predIndex, score] = recognizer.generate(croppedChar, constants::kRecognizerHeight); if (!predIndex.empty()) { text += converter.decodeGreedy(predIndex, predIndex.size())[0]; } @@ -135,11 +121,11 @@ std::pair VerticalOCR::_handleIndependentCharacters( } // Strategy 2: Concatenate characters and recognize as a single line -std::pair VerticalOCR::_handleJointCharacters( - const types::DetectorBBox &box, const cv::Mat &originalImage, - const std::vector &characterBoxes, - const types::PaddingInfo &paddingsBox, - const types::PaddingInfo &imagePaddings) { +std::pair +VerticalOCR::_handleJointCharacters(const types::DetectorBBox &box, const cv::Mat &originalImage, + const std::vector &characterBoxes, + const types::PaddingInfo &paddingsBox, + const types::PaddingInfo &imagePaddings) { std::string text; std::vector croppedCharacters; croppedCharacters.reserve(characterBoxes.size()); @@ -152,18 +138,17 @@ std::pair VerticalOCR::_handleJointCharacters( 64). The same height is required for horizontal concatenation of single characters into one image. */ - auto croppedChar = utils::prepareForRecognition( - originalImage, characterBox.bbox, box.bbox, paddingsBox, imagePaddings); + auto croppedChar = utils::prepareForRecognition(originalImage, characterBox.bbox, box.bbox, + paddingsBox, imagePaddings); croppedCharacters.push_back(croppedChar); } cv::Mat mergedCharacters; cv::hconcat(croppedCharacters, mergedCharacters); mergedCharacters = image_processing::resizePadded( - mergedCharacters, - cv::Size(constants::kLargeRecognizerWidth, constants::kRecognizerHeight)); - mergedCharacters = utils::normalizeForRecognizer( - mergedCharacters, constants::kRecognizerHeight, 0.0, false); + mergedCharacters, cv::Size(constants::kLargeRecognizerWidth, constants::kRecognizerHeight)); + mergedCharacters = + utils::normalizeForRecognizer(mergedCharacters, constants::kRecognizerHeight, 0.0, false); const auto &[predIndex, confidenceScore] = recognizer.generate(mergedCharacters, constants::kLargeRecognizerWidth); @@ -173,15 +158,15 @@ std::pair VerticalOCR::_handleJointCharacters( return {text, confidenceScore}; } -types::OCRDetection VerticalOCR::_processSingleTextBox( - types::DetectorBBox &box, const cv::Mat &originalImage, - const cv::Mat &resizedLargeImage, const types::PaddingInfo &imagePaddings) { +types::OCRDetection VerticalOCR::_processSingleTextBox(types::DetectorBBox &box, + const cv::Mat &originalImage, + const cv::Mat &resizedLargeImage, + const types::PaddingInfo &imagePaddings) { cv::Rect boundingBox = utils::extractBoundingBox(box.bbox); // Crop the image for detection of single characters. - cv::Rect safeRect = - boundingBox & cv::Rect(0, 0, resizedLargeImage.cols, - resizedLargeImage.rows); // ensure valid box + cv::Rect safeRect = boundingBox & cv::Rect(0, 0, resizedLargeImage.cols, + resizedLargeImage.rows); // ensure valid box cv::Mat croppedLargeBox = resizedLargeImage(safeRect); // 2. Narrow Detector - detects single characters @@ -196,16 +181,15 @@ types::OCRDetection VerticalOCR::_processSingleTextBox( const int32_t boxHeight = static_cast(box.bbox.height()); cv::Size narrowRecognizerSize = detector.calculateModelImageSize(constants::kSmallDetectorWidth); - types::PaddingInfo paddingsBox = utils::calculateResizeRatioAndPaddings( - cv::Size(boxWidth, boxHeight), narrowRecognizerSize); + types::PaddingInfo paddingsBox = + utils::calculateResizeRatioAndPaddings(cv::Size(boxWidth, boxHeight), narrowRecognizerSize); // 3. Recognition - decide between Strategy 1 and Strategy 2. std::tie(text, confidenceScore) = - independentCharacters - ? _handleIndependentCharacters(box, originalImage, characterBoxes, - paddingsBox, imagePaddings) - : _handleJointCharacters(box, originalImage, characterBoxes, - paddingsBox, imagePaddings); + independentCharacters ? _handleIndependentCharacters(box, originalImage, characterBoxes, + paddingsBox, imagePaddings) + : _handleJointCharacters(box, originalImage, characterBoxes, + paddingsBox, imagePaddings); } // Modify the returned boxes to match the original image size. const float ratio = imagePaddings.resizeRatio; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.h b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.h index 4016e28138..d5256ac8d9 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.h @@ -45,9 +45,8 @@ using executorch::extension::TensorPtr; class VerticalOCR final { public: - explicit VerticalOCR(const std::string &detectorSource, - const std::string &recognizerSource, std::string symbols, - bool indpendentCharacters, + explicit VerticalOCR(const std::string &detectorSource, const std::string &recognizerSource, + std::string symbols, bool indpendentCharacters, std::shared_ptr callInvoker); [[nodiscard("Registered non-void function")]] std::vector generateFromString(std::string input); @@ -61,22 +60,20 @@ class VerticalOCR final { private: std::vector runInference(cv::Mat image); - std::pair _handleIndependentCharacters( - const types::DetectorBBox &box, const cv::Mat &originalImage, - const std::vector &characterBoxes, - const types::PaddingInfo &paddingsBox, - const types::PaddingInfo &imagePaddings); + std::pair + _handleIndependentCharacters(const types::DetectorBBox &box, const cv::Mat &originalImage, + const std::vector &characterBoxes, + const types::PaddingInfo &paddingsBox, + const types::PaddingInfo &imagePaddings); std::pair - _handleJointCharacters(const types::DetectorBBox &box, - const cv::Mat &originalImage, + _handleJointCharacters(const types::DetectorBBox &box, const cv::Mat &originalImage, const std::vector &characterBoxes, const types::PaddingInfo &paddingsBox, const types::PaddingInfo &imagePaddings); - types::OCRDetection - _processSingleTextBox(types::DetectorBBox &box, const cv::Mat &originalImage, - const cv::Mat &resizedLargeImage, - const types::PaddingInfo &imagePaddings); + types::OCRDetection _processSingleTextBox(types::DetectorBBox &box, const cv::Mat &originalImage, + const cv::Mat &resizedLargeImage, + const types::PaddingInfo &imagePaddings); VerticalDetector detector; Recognizer recognizer; @@ -87,6 +84,6 @@ class VerticalOCR final { }; } // namespace models::ocr -REGISTER_CONSTRUCTOR(models::ocr::VerticalOCR, std::string, std::string, - std::string, bool, std::shared_ptr); +REGISTER_CONSTRUCTOR(models::ocr::VerticalOCR, std::string, std::string, std::string, bool, + std::shared_ptr); } // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Constants.h b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Constants.h index a46947572f..62456ef4f6 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Constants.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Constants.h @@ -24,8 +24,7 @@ inline constexpr auto kSpeechThreshold = 0.6f; inline constexpr size_t kMinSpeechDurationMs = 250; inline constexpr size_t kMinSilenceDurationMs = 100; inline constexpr size_t kSpeechPadMs = 30; -inline constexpr size_t kStreamBufferMaxSize = 10 * kSampleRate; // 10s -inline constexpr size_t kStreamBufferMinReserve = - 1 * kSampleRate; // 1s of audio +inline constexpr size_t kStreamBufferMaxSize = 10 * kSampleRate; // 10s +inline constexpr size_t kStreamBufferMinReserve = 1 * kSampleRate; // 1s of audio } // namespace rnexecutorch::models::voice_activity_detection::constants diff --git a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Types.h b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Types.h index 51794d6bf5..2a68efbd5a 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Types.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Types.h @@ -9,4 +9,4 @@ struct Segment { size_t end; }; -} // namespace rnexecutorch::models::voice_activity_detection::types \ No newline at end of file +} // namespace rnexecutorch::models::voice_activity_detection::types diff --git a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Utils.cpp b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Utils.cpp index ff26372896..57286bc6d3 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Utils.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Utils.cpp @@ -4,9 +4,8 @@ #include namespace rnexecutorch::models::voice_activity_detection::utils { -size_t getNonSpeechClassProbabilites(const executorch::aten::Tensor &tensor, - size_t numClass, size_t size, - std::vector &resultVector, +size_t getNonSpeechClassProbabilites(const executorch::aten::Tensor &tensor, size_t numClass, + size_t size, std::vector &resultVector, size_t startIdx) { const auto *rawData = tensor.const_data_ptr(); for (size_t i = 0; i < size; i++) { @@ -15,8 +14,8 @@ size_t getNonSpeechClassProbabilites(const executorch::aten::Tensor &tensor, return startIdx + size; } -std::vector -mergeSegments(const std::vector &segments, size_t maxMergeGap) { +std::vector mergeSegments(const std::vector &segments, + size_t maxMergeGap) { if (segments.empty()) { return segments; } @@ -28,8 +27,7 @@ mergeSegments(const std::vector &segments, size_t maxMergeGap) { auto &lastMerged = mergedSegments.back(); const auto ¤t = segments[i]; - if (current.start < lastMerged.end || - current.start - lastMerged.end <= maxMergeGap) { + if (current.start < lastMerged.end || current.start - lastMerged.end <= maxMergeGap) { lastMerged.end = std::max(lastMerged.end, current.end); } else { mergedSegments.push_back(current); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Utils.h b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Utils.h index 3ec8212448..5eb7ae7397 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Utils.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/Utils.h @@ -6,9 +6,8 @@ #include namespace rnexecutorch::models::voice_activity_detection::utils { -size_t getNonSpeechClassProbabilites(const executorch::aten::Tensor &tensor, - size_t numClass, size_t size, - std::vector &resultVector, +size_t getNonSpeechClassProbabilites(const executorch::aten::Tensor &tensor, size_t numClass, + size_t size, std::vector &resultVector, size_t startIdx); /** @@ -20,7 +19,7 @@ size_t getNonSpeechClassProbabilites(const executorch::aten::Tensor &tensor, * samples) to qualify them for a merge. * @return A new collection containing the merged speech segments. */ -std::vector -mergeSegments(const std::vector &segments, size_t maxMergeGap); +std::vector mergeSegments(const std::vector &segments, + size_t maxMergeGap); } // namespace rnexecutorch::models::voice_activity_detection::utils diff --git a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp index 77c3a16cc8..9e060440bb 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.cpp @@ -18,64 +18,58 @@ using namespace constants; using executorch::aten::Tensor; using executorch::extension::TensorPtr; -VoiceActivityDetection::VoiceActivityDetection( - const std::string &modelSource, - std::shared_ptr callInvoker) +VoiceActivityDetection::VoiceActivityDetection(const std::string &modelSource, + std::shared_ptr callInvoker) : BaseModel(modelSource, callInvoker), callInvoker_(callInvoker) { // Important - preallocate memory for the buffer to avoid any reallocations. audioBuffer_.reserve(2 * constants::kStreamBufferMaxSize); } -std::vector -VoiceActivityDetection::generate(std::span waveform, - uint32_t mergeGap) const { +std::vector VoiceActivityDetection::generate(std::span waveform, + uint32_t mergeGap) const { // Guard against small buffers to prevent underflow in preprocess if (waveform.size() < kWindowSize) { return {}; } auto windowedInput = preprocess(waveform); - auto [chunksNumber, remainder] = std::div( - static_cast(windowedInput.size()), static_cast(kModelInputMax)); + auto [chunksNumber, remainder] = + std::div(static_cast(windowedInput.size()), static_cast(kModelInputMax)); std::vector scores(windowedInput.size()); auto lastChunkSize = remainder; if (remainder < kModelInputMin) { auto paddingSize = kModelInputMin - remainder; lastChunkSize = kModelInputMin; - windowedInput.insert(windowedInput.end(), paddingSize, - std::array{}); + windowedInput.insert(windowedInput.end(), paddingSize, std::array{}); } TensorPtr inputTensor; size_t startIdx = 0; for (size_t i = 0; i < chunksNumber; i++) { - std::span> chunk( - windowedInput.data() + kModelInputMax * i, kModelInputMax); + std::span> chunk(windowedInput.data() + kModelInputMax * i, + kModelInputMax); inputTensor = executorch::extension::from_blob( - chunk.data(), {kModelInputMax, kPaddedWindowSize}, - executorch::aten::ScalarType::Float); + chunk.data(), {kModelInputMax, kPaddedWindowSize}, executorch::aten::ScalarType::Float); auto forwardResult = BaseModel::forward(inputTensor); CHECK_OK_OR_THROW_FORWARD_ERROR(forwardResult); auto tensor = forwardResult->at(0).toTensor(); - startIdx = utils::getNonSpeechClassProbabilites( - tensor, tensor.size(2), tensor.size(1), scores, startIdx); + startIdx = utils::getNonSpeechClassProbabilites(tensor, tensor.size(2), tensor.size(1), scores, + startIdx); } std::span> lastChunk( windowedInput.data() + kModelInputMax * chunksNumber, lastChunkSize); inputTensor = executorch::extension::from_blob( - lastChunk.data(), {lastChunkSize, kPaddedWindowSize}, - executorch::aten::ScalarType::Float); + lastChunk.data(), {lastChunkSize, kPaddedWindowSize}, executorch::aten::ScalarType::Float); auto forwardResult = BaseModel::forward(inputTensor); CHECK_OK_OR_THROW_FORWARD_ERROR(forwardResult); auto tensor = forwardResult->at(0).toTensor(); - startIdx = utils::getNonSpeechClassProbabilites(tensor, tensor.size(2), - remainder, scores, startIdx); + startIdx = + utils::getNonSpeechClassProbabilites(tensor, tensor.size(2), remainder, scores, startIdx); return postprocess(scores, kSpeechThreshold, mergeGap); } -void VoiceActivityDetection::stream(std::shared_ptr callback, - uint32_t timeout, +void VoiceActivityDetection::stream(std::shared_ptr callback, uint32_t timeout, uint32_t detectionMargin) { bool expected = false; if (!isStreaming_.compare_exchange_strong(expected, true)) { @@ -97,9 +91,8 @@ void VoiceActivityDetection::stream(std::shared_ptr callback, // where true corresponds to detected ongoing speech, and false corresponds to // silence. auto nativeCallback = [this, callback](bool speaking) { - callInvoker_->invokeAsync([callback, speaking](jsi::Runtime &rt) { - callback->call(rt, jsi::Value(speaking)); - }); + callInvoker_->invokeAsync( + [callback, speaking](jsi::Runtime &rt) { callback->call(rt, jsi::Value(speaking)); }); }; while (isStreaming_) { @@ -135,8 +128,7 @@ void VoiceActivityDetection::stream(std::shared_ptr callback, auto speechEnd = lastSegment.end; std::scoped_lock lock(audioBufferMutex_); - uint32_t diffMs = - (audioBuffer_.size() - speechEnd) / constants::kSamplesPerMs; // [ms] + uint32_t diffMs = (audioBuffer_.size() - speechEnd) / constants::kSamplesPerMs; // [ms] speaking = diffMs <= detectionMargin; } @@ -175,27 +167,24 @@ VoiceActivityDetection::preprocess(std::span waveform) const { auto windowView = waveform.subspan(i * kHopLength, kWindowSize); ranges::copy(windowView, frameBuffer[i].begin() + leftPadding); - auto frameView = - std::span{frameBuffer[i].data() + leftPadding, kWindowSize}; + auto frameView = std::span{frameBuffer[i].data() + leftPadding, kWindowSize}; const float sum = std::reduce(frameView.begin(), frameView.end(), 0.0f); const float mean = sum / kWindowSize; - ranges::transform(frameView, frameView.begin(), - [mean](float value) { return value - mean; }); + ranges::transform(frameView, frameView.begin(), [mean](float value) { return value - mean; }); // apply pre-emphasis filter for (auto j = frameView.size() - 1; j > 0; --j) { frameView[j] -= kPreemphasisCoeff * frameView[j - 1]; } // apply hamming window to reduce spectral leakage - ranges::transform(frameView, kHammingWindowArray, frameView.begin(), - std::multiplies{}); + ranges::transform(frameView, kHammingWindowArray, frameView.begin(), std::multiplies{}); } return frameBuffer; } -std::vector -VoiceActivityDetection::postprocess(const std::vector &scores, - float threshold, uint32_t mergeGap) const { +std::vector VoiceActivityDetection::postprocess(const std::vector &scores, + float threshold, + uint32_t mergeGap) const { bool triggered = false; std::vector speechSegments{}; ssize_t startSegment = -1; diff --git a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h index d9d72adbba..3e905ddb3b 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h @@ -38,8 +38,7 @@ class VoiceActivityDetection : public BaseModel { * @param detectionMargin Specifies (in miliseconds) how far the last detected * speech segment can be to still be considered as ongoing speech. */ - void stream(std::shared_ptr callback, uint32_t timeout, - uint32_t detectionMargin); + void stream(std::shared_ptr callback, uint32_t timeout, uint32_t detectionMargin); /** * When called, stops the streaming procedure. @@ -54,8 +53,7 @@ class VoiceActivityDetection : public BaseModel { private: std::vector> preprocess(std::span waveform) const; - std::vector postprocess(const std::vector &scores, - float threshold, + std::vector postprocess(const std::vector &scores, float threshold, uint32_t mergeGap) const; std::shared_ptr callInvoker_; @@ -71,6 +69,6 @@ class VoiceActivityDetection : public BaseModel { } // namespace models::voice_activity_detection -REGISTER_CONSTRUCTOR(models::voice_activity_detection::VoiceActivityDetection, - std::string, std::shared_ptr); -} // namespace rnexecutorch \ No newline at end of file +REGISTER_CONSTRUCTOR(models::voice_activity_detection::VoiceActivityDetection, std::string, + std::shared_ptr); +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/BaseModelTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/BaseModelTest.cpp index 4889537389..3d621d887c 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/BaseModelTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/BaseModelTest.cpp @@ -12,8 +12,7 @@ using namespace executorch::extension; using namespace model_tests; using executorch::runtime::EValue; -constexpr auto kValidStyleTransferModelPath = - "style_transfer_candy_xnnpack_fp32.pte"; +constexpr auto kValidStyleTransferModelPath = "style_transfer_candy_xnnpack_fp32.pte"; // ============================================================================ // Common tests via typed test suite @@ -23,13 +22,9 @@ namespace model_tests { template <> struct ModelTraits { using ModelType = BaseModel; - static ModelType createValid() { - return ModelType(kValidStyleTransferModelPath, nullptr); - } + static ModelType createValid() { return ModelType(kValidStyleTransferModelPath, nullptr); } - static ModelType createInvalid() { - return ModelType("nonexistent.pte", nullptr); - } + static ModelType createInvalid() { return ModelType("nonexistent.pte", nullptr); } static void callGenerate(ModelType &model) { std::vector shape = {1, 3, 640, 640}; @@ -58,15 +53,13 @@ TEST(BaseModelGetInputShapeTests, ValidMethodCorrectShape) { TEST(BaseModelGetInputShapeTests, InvalidMethodThrows) { BaseModel model(kValidStyleTransferModelPath, nullptr); - EXPECT_THROW((void)model.getInputShape("this_method_does_not_exist", 0), - RnExecutorchError); + EXPECT_THROW((void)model.getInputShape("this_method_does_not_exist", 0), RnExecutorchError); } TEST(BaseModelGetInputShapeTests, ValidMethodInvalidIndexThrows) { BaseModel model(kValidStyleTransferModelPath, nullptr); - EXPECT_THROW( - (void)model.getInputShape("forward", std::numeric_limits::min()), - RnExecutorchError); + EXPECT_THROW((void)model.getInputShape("forward", std::numeric_limits::min()), + RnExecutorchError); } TEST(BaseModelGetAllInputShapesTests, ValidMethodReturnsShapes) { @@ -79,8 +72,7 @@ TEST(BaseModelGetAllInputShapesTests, ValidMethodReturnsShapes) { TEST(BaseModelGetAllInputShapesTests, InvalidMethodThrows) { BaseModel model(kValidStyleTransferModelPath, nullptr); - EXPECT_THROW(model.getAllInputShapes("non_existent_method"), - RnExecutorchError); + EXPECT_THROW(model.getAllInputShapes("non_existent_method"), RnExecutorchError); } TEST(BaseModelGetMethodMetaTests, ValidMethodReturnsOk) { diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/BaseModelTests.h b/packages/react-native-executorch/common/rnexecutorch/tests/integration/BaseModelTests.h index af00a2164d..1660ab941f 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/BaseModelTests.h +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/BaseModelTests.h @@ -114,10 +114,9 @@ TYPED_TEST_P(CommonModelTest, MultipleGeneratesWork) { // TODO: Investigate why TextToImage fails on MultipleGeneratesWork in the // emulator environment -REGISTER_TYPED_TEST_SUITE_P(CommonModelTest, InvalidPathThrows, - ValidPathDoesntThrow, GetMemoryLowerBoundValue, - GetMemoryLowerBoundConsistent, UnloadDoesntThrow, - MultipleUnloadsSafe, GenerateAfterUnloadThrows, +REGISTER_TYPED_TEST_SUITE_P(CommonModelTest, InvalidPathThrows, ValidPathDoesntThrow, + GetMemoryLowerBoundValue, GetMemoryLowerBoundConsistent, + UnloadDoesntThrow, MultipleUnloadsSafe, GenerateAfterUnloadThrows, MultipleGeneratesWork); } // namespace model_tests diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp index 6991c2fe06..dbf455df0d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp @@ -10,8 +10,7 @@ using namespace rnexecutorch::models::classification; using namespace model_tests; constexpr auto kValidClassificationModelPath = "efficientnet_v2_s_xnnpack.pte"; -constexpr auto kValidTestImagePath = - "file:///data/local/tmp/rnexecutorch_tests/test_image.jpg"; +constexpr auto kValidTestImagePath = "file:///data/local/tmp/rnexecutorch_tests/test_image.jpg"; constexpr size_t kImagenet1kNumClasses = 1000; static std::vector kImagenetNormMean = {0.485f, 0.456f, 0.406f}; @@ -34,13 +33,11 @@ template <> struct ModelTraits { using ModelType = Classification; static ModelType createValid() { - return ModelType(kValidClassificationModelPath, kImagenetNormMean, - kImagenetNormStd, getImagenetLabelNames(), nullptr); + return ModelType(kValidClassificationModelPath, kImagenetNormMean, kImagenetNormStd, + getImagenetLabelNames(), nullptr); } - static ModelType createInvalid() { - return ModelType("nonexistent.pte", {}, {}, {}, nullptr); - } + static ModelType createInvalid() { return ModelType("nonexistent.pte", {}, {}, {}, nullptr); } static void callGenerate(ModelType &model) { (void)model.generateFromString(kValidTestImagePath); @@ -49,52 +46,48 @@ template <> struct ModelTraits { } // namespace model_tests using ClassificationTypes = ::testing::Types; -INSTANTIATE_TYPED_TEST_SUITE_P(Classification, CommonModelTest, - ClassificationTypes); -INSTANTIATE_TYPED_TEST_SUITE_P(Classification, VisionModelTest, - ClassificationTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(Classification, CommonModelTest, ClassificationTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(Classification, VisionModelTest, ClassificationTypes); // ============================================================================ // Model-specific tests // ============================================================================ TEST(ClassificationGenerateTests, InvalidImagePathThrows) { - Classification model(kValidClassificationModelPath, kImagenetNormMean, - kImagenetNormStd, getImagenetLabelNames(), nullptr); - EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"), - RnExecutorchError); + Classification model(kValidClassificationModelPath, kImagenetNormMean, kImagenetNormStd, + getImagenetLabelNames(), nullptr); + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"), RnExecutorchError); } TEST(ClassificationGenerateTests, EmptyImagePathThrows) { - Classification model(kValidClassificationModelPath, kImagenetNormMean, - kImagenetNormStd, getImagenetLabelNames(), nullptr); + Classification model(kValidClassificationModelPath, kImagenetNormMean, kImagenetNormStd, + getImagenetLabelNames(), nullptr); EXPECT_THROW((void)model.generateFromString(""), RnExecutorchError); } TEST(ClassificationGenerateTests, MalformedURIThrows) { - Classification model(kValidClassificationModelPath, kImagenetNormMean, - kImagenetNormStd, getImagenetLabelNames(), nullptr); - EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"), - RnExecutorchError); + Classification model(kValidClassificationModelPath, kImagenetNormMean, kImagenetNormStd, + getImagenetLabelNames(), nullptr); + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"), RnExecutorchError); } TEST(ClassificationGenerateTests, ValidImageReturnsResults) { - Classification model(kValidClassificationModelPath, kImagenetNormMean, - kImagenetNormStd, getImagenetLabelNames(), nullptr); + Classification model(kValidClassificationModelPath, kImagenetNormMean, kImagenetNormStd, + getImagenetLabelNames(), nullptr); auto results = model.generateFromString(kValidTestImagePath); EXPECT_FALSE(results.empty()); } TEST(ClassificationGenerateTests, ResultsHaveCorrectSize) { - Classification model(kValidClassificationModelPath, kImagenetNormMean, - kImagenetNormStd, getImagenetLabelNames(), nullptr); + Classification model(kValidClassificationModelPath, kImagenetNormMean, kImagenetNormStd, + getImagenetLabelNames(), nullptr); auto results = model.generateFromString(kValidTestImagePath); auto expectedNumClasses = kImagenet1kNumClasses; EXPECT_EQ(results.size(), expectedNumClasses); } TEST(ClassificationGenerateTests, ResultsContainValidProbabilities) { - Classification model(kValidClassificationModelPath, kImagenetNormMean, - kImagenetNormStd, getImagenetLabelNames(), nullptr); + Classification model(kValidClassificationModelPath, kImagenetNormMean, kImagenetNormStd, + getImagenetLabelNames(), nullptr); auto results = model.generateFromString(kValidTestImagePath); float sum = 0.0f; @@ -107,8 +100,8 @@ TEST(ClassificationGenerateTests, ResultsContainValidProbabilities) { } TEST(ClassificationGenerateTests, TopPredictionHasReasonableConfidence) { - Classification model(kValidClassificationModelPath, kImagenetNormMean, - kImagenetNormStd, getImagenetLabelNames(), nullptr); + Classification model(kValidClassificationModelPath, kImagenetNormMean, kImagenetNormStd, + getImagenetLabelNames(), nullptr); auto results = model.generateFromString(kValidTestImagePath); float maxProb = 0.0f; @@ -121,15 +114,14 @@ TEST(ClassificationGenerateTests, TopPredictionHasReasonableConfidence) { } TEST(ClassificationGenerateTests, WrongLabelCountThrows) { - Classification model(kValidClassificationModelPath, kImagenetNormMean, - kImagenetNormStd, {"A", "B", "C"}, nullptr); - EXPECT_THROW((void)model.generateFromString(kValidTestImagePath), - RnExecutorchError); + Classification model(kValidClassificationModelPath, kImagenetNormMean, kImagenetNormStd, + {"A", "B", "C"}, nullptr); + EXPECT_THROW((void)model.generateFromString(kValidTestImagePath), RnExecutorchError); } TEST(ClassificationInheritedTests, GetInputShapeWorks) { - Classification model(kValidClassificationModelPath, kImagenetNormMean, - kImagenetNormStd, getImagenetLabelNames(), nullptr); + Classification model(kValidClassificationModelPath, kImagenetNormMean, kImagenetNormStd, + getImagenetLabelNames(), nullptr); auto shape = model.getInputShape("forward", 0); EXPECT_EQ(shape.size(), 4); EXPECT_EQ(shape[0], 1); @@ -137,15 +129,15 @@ TEST(ClassificationInheritedTests, GetInputShapeWorks) { } TEST(ClassificationInheritedTests, GetAllInputShapesWorks) { - Classification model(kValidClassificationModelPath, kImagenetNormMean, - kImagenetNormStd, getImagenetLabelNames(), nullptr); + Classification model(kValidClassificationModelPath, kImagenetNormMean, kImagenetNormStd, + getImagenetLabelNames(), nullptr); auto shapes = model.getAllInputShapes("forward"); EXPECT_FALSE(shapes.empty()); } TEST(ClassificationInheritedTests, GetMethodMetaWorks) { - Classification model(kValidClassificationModelPath, kImagenetNormMean, - kImagenetNormStd, getImagenetLabelNames(), nullptr); + Classification model(kValidClassificationModelPath, kImagenetNormMean, kImagenetNormStd, + getImagenetLabelNames(), nullptr); auto result = model.getMethodMeta("forward"); EXPECT_TRUE(result.ok()); } @@ -154,11 +146,10 @@ TEST(ClassificationInheritedTests, GetMethodMetaWorks) { // generateFromPixels smoke test // ============================================================================ TEST(ClassificationPixelTests, ValidPixelsReturnsResults) { - Classification model(kValidClassificationModelPath, kImagenetNormMean, - kImagenetNormStd, getImagenetLabelNames(), nullptr); + Classification model(kValidClassificationModelPath, kImagenetNormMean, kImagenetNormStd, + getImagenetLabelNames(), nullptr); std::vector buf(64 * 64 * 3, 128); - JSTensorViewIn view{ - buf.data(), {64, 64, 3}, executorch::aten::ScalarType::Byte}; + JSTensorViewIn view{buf.data(), {64, 64, 3}, executorch::aten::ScalarType::Byte}; auto results = model.generateFromPixels(view); EXPECT_FALSE(results.empty()); } diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp index 4982206614..45793e415b 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp @@ -11,10 +11,8 @@ using namespace rnexecutorch; using namespace rnexecutorch::models::embeddings; using namespace model_tests; -constexpr auto kValidImageEmbeddingsModelPath = - "clip-vit-base-patch32-vision_xnnpack.pte"; -constexpr auto kValidTestImagePath = - "file:///data/local/tmp/rnexecutorch_tests/test_image.jpg"; +constexpr auto kValidImageEmbeddingsModelPath = "clip-vit-base-patch32-vision_xnnpack.pte"; +constexpr auto kValidTestImagePath = "file:///data/local/tmp/rnexecutorch_tests/test_image.jpg"; // ============================================================================ // Common tests via typed test suite @@ -23,13 +21,9 @@ namespace model_tests { template <> struct ModelTraits { using ModelType = ImageEmbeddings; - static ModelType createValid() { - return ModelType(kValidImageEmbeddingsModelPath, nullptr); - } + static ModelType createValid() { return ModelType(kValidImageEmbeddingsModelPath, nullptr); } - static ModelType createInvalid() { - return ModelType("nonexistent.pte", nullptr); - } + static ModelType createInvalid() { return ModelType("nonexistent.pte", nullptr); } static void callGenerate(ModelType &model) { (void)model.generateFromString(kValidTestImagePath); @@ -38,18 +32,15 @@ template <> struct ModelTraits { } // namespace model_tests using ImageEmbeddingsTypes = ::testing::Types; -INSTANTIATE_TYPED_TEST_SUITE_P(ImageEmbeddings, CommonModelTest, - ImageEmbeddingsTypes); -INSTANTIATE_TYPED_TEST_SUITE_P(ImageEmbeddings, VisionModelTest, - ImageEmbeddingsTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(ImageEmbeddings, CommonModelTest, ImageEmbeddingsTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(ImageEmbeddings, VisionModelTest, ImageEmbeddingsTypes); // ============================================================================ // Model-specific tests // ============================================================================ TEST(ImageEmbeddingsGenerateTests, InvalidImagePathThrows) { ImageEmbeddings model(kValidImageEmbeddingsModelPath, nullptr); - EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"), - RnExecutorchError); + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"), RnExecutorchError); } TEST(ImageEmbeddingsGenerateTests, EmptyImagePathThrows) { @@ -59,8 +50,7 @@ TEST(ImageEmbeddingsGenerateTests, EmptyImagePathThrows) { TEST(ImageEmbeddingsGenerateTests, MalformedURIThrows) { ImageEmbeddings model(kValidImageEmbeddingsModelPath, nullptr); - EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"), - RnExecutorchError); + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"), RnExecutorchError); } TEST(ImageEmbeddingsGenerateTests, ValidImageReturnsResults) { @@ -134,8 +124,7 @@ TEST(ImageEmbeddingsInheritedTests, GetMethodMetaWorks) { TEST(ImageEmbeddingsPixelTests, ValidPixelsReturnsEmbedding) { ImageEmbeddings model(kValidImageEmbeddingsModelPath, nullptr); std::vector buf(64 * 64 * 3, 128); - JSTensorViewIn view{ - buf.data(), {64, 64, 3}, executorch::aten::ScalarType::Byte}; + JSTensorViewIn view{buf.data(), {64, 64, 3}, executorch::aten::ScalarType::Byte}; auto result = model.generateFromPixels(view); EXPECT_NE(result, nullptr); EXPECT_GT(result->size(), 0u); diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/InstanceSegmentationTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/InstanceSegmentationTest.cpp index ff003eb62d..f5aac3e9f4 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/InstanceSegmentationTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/InstanceSegmentationTest.cpp @@ -27,105 +27,87 @@ template <> struct ModelTraits { return ModelType(kValidInstanceSegModelPath, {}, {}, true, nullptr); } - static ModelType createInvalid() { - return ModelType("nonexistent.pte", {}, {}, true, nullptr); - } + static ModelType createInvalid() { return ModelType("nonexistent.pte", {}, {}, true, nullptr); } static void callGenerate(ModelType &model) { - (void)model.generateFromString(kValidTestImagePath, 0.5, 0.5, 100, {}, true, - kMethodName); + (void)model.generateFromString(kValidTestImagePath, 0.5, 0.5, 100, {}, true, kMethodName); } }; } // namespace model_tests using InstanceSegmentationTypes = ::testing::Types; -INSTANTIATE_TYPED_TEST_SUITE_P(InstanceSegmentation, CommonModelTest, - InstanceSegmentationTypes); -INSTANTIATE_TYPED_TEST_SUITE_P(InstanceSegmentation, VisionModelTest, - InstanceSegmentationTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(InstanceSegmentation, CommonModelTest, InstanceSegmentationTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(InstanceSegmentation, VisionModelTest, InstanceSegmentationTypes); // ============================================================================ // Model-specific tests // ============================================================================ TEST(InstanceSegGenerateTests, InvalidImagePathThrows) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); - EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg", 0.5, 0.5, - 100, {}, true, kMethodName), - RnExecutorchError); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); + EXPECT_THROW( + (void)model.generateFromString("nonexistent_image.jpg", 0.5, 0.5, 100, {}, true, kMethodName), + RnExecutorchError); } TEST(InstanceSegGenerateTests, EmptyImagePathThrows) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); - EXPECT_THROW( - (void)model.generateFromString("", 0.5, 0.5, 100, {}, true, kMethodName), - RnExecutorchError); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); + EXPECT_THROW((void)model.generateFromString("", 0.5, 0.5, 100, {}, true, kMethodName), + RnExecutorchError); } TEST(InstanceSegGenerateTests, EmptyMethodNameThrows) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); - EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.5, - 100, {}, true, ""), + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); + EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.5, 100, {}, true, ""), RnExecutorchError); } TEST(InstanceSegGenerateTests, NegativeConfidenceThrows) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); - EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, -0.1, 0.5, - 100, {}, true, kMethodName), - RnExecutorchError); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); + EXPECT_THROW( + (void)model.generateFromString(kValidTestImagePath, -0.1, 0.5, 100, {}, true, kMethodName), + RnExecutorchError); } TEST(InstanceSegGenerateTests, ConfidenceAboveOneThrows) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); - EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 1.1, 0.5, - 100, {}, true, kMethodName), - RnExecutorchError); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); + EXPECT_THROW( + (void)model.generateFromString(kValidTestImagePath, 1.1, 0.5, 100, {}, true, kMethodName), + RnExecutorchError); } TEST(InstanceSegGenerateTests, NegativeIouThresholdThrows) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); - EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, -0.1, - 100, {}, true, kMethodName), - RnExecutorchError); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); + EXPECT_THROW( + (void)model.generateFromString(kValidTestImagePath, 0.5, -0.1, 100, {}, true, kMethodName), + RnExecutorchError); } TEST(InstanceSegGenerateTests, IouThresholdAboveOneThrows) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); - EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 1.1, - 100, {}, true, kMethodName), - RnExecutorchError); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); + EXPECT_THROW( + (void)model.generateFromString(kValidTestImagePath, 0.5, 1.1, 100, {}, true, kMethodName), + RnExecutorchError); } TEST(InstanceSegGenerateTests, ValidImageReturnsResults) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); - auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, - {}, true, kMethodName); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); + auto results = + model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName); EXPECT_FALSE(results.empty()); } TEST(InstanceSegGenerateTests, HighThresholdReturnsFewerResults) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); - auto lowResults = model.generateFromString(kValidTestImagePath, 0.1, 0.5, 100, - {}, true, kMethodName); - auto highResults = model.generateFromString(kValidTestImagePath, 0.9, 0.5, - 100, {}, true, kMethodName); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); + auto lowResults = + model.generateFromString(kValidTestImagePath, 0.1, 0.5, 100, {}, true, kMethodName); + auto highResults = + model.generateFromString(kValidTestImagePath, 0.9, 0.5, 100, {}, true, kMethodName); EXPECT_GE(lowResults.size(), highResults.size()); } TEST(InstanceSegGenerateTests, MaxInstancesLimitsResults) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); - auto results = model.generateFromString(kValidTestImagePath, 0.1, 0.5, 2, {}, - true, kMethodName); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); + auto results = model.generateFromString(kValidTestImagePath, 0.1, 0.5, 2, {}, true, kMethodName); EXPECT_LE(results.size(), 2u); } @@ -133,10 +115,9 @@ TEST(InstanceSegGenerateTests, MaxInstancesLimitsResults) { // Result validation tests // ============================================================================ TEST(InstanceSegResultTests, InstancesHaveValidBoundingBoxes) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); - auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, - {}, true, kMethodName); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); + auto results = + model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName); for (const auto &inst : results) { EXPECT_LE(inst.bbox.p1.x, inst.bbox.p2.x); @@ -147,10 +128,9 @@ TEST(InstanceSegResultTests, InstancesHaveValidBoundingBoxes) { } TEST(InstanceSegResultTests, InstancesHaveValidScores) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); - auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, - {}, true, kMethodName); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); + auto results = + model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName); for (const auto &inst : results) { EXPECT_GE(inst.score, 0.0f); @@ -159,16 +139,14 @@ TEST(InstanceSegResultTests, InstancesHaveValidScores) { } TEST(InstanceSegResultTests, InstancesHaveValidMasks) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); - auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, - {}, true, kMethodName); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); + auto results = + model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName); for (const auto &inst : results) { EXPECT_GT(inst.maskWidth, 0); EXPECT_GT(inst.maskHeight, 0); - EXPECT_EQ(inst.mask->size(), - static_cast(inst.maskWidth) * inst.maskHeight); + EXPECT_EQ(inst.mask->size(), static_cast(inst.maskWidth) * inst.maskHeight); for (size_t i = 0; i < inst.mask->size(); ++i) { uint8_t val = inst.mask->data()[i]; @@ -178,10 +156,9 @@ TEST(InstanceSegResultTests, InstancesHaveValidMasks) { } TEST(InstanceSegResultTests, InstancesHaveValidClassIndices) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); - auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, - {}, true, kMethodName); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); + auto results = + model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName); for (const auto &inst : results) { EXPECT_GE(inst.classIndex, 0); @@ -193,12 +170,11 @@ TEST(InstanceSegResultTests, InstancesHaveValidClassIndices) { // Class filtering tests // ============================================================================ TEST(InstanceSegFilterTests, ClassFilterReturnsOnlyMatchingClasses) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); // Filter to class index 0 (PERSON in CocoLabelYolo) std::vector classIndices = {0}; - auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, - classIndices, true, kMethodName); + auto results = + model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, classIndices, true, kMethodName); for (const auto &inst : results) { EXPECT_EQ(inst.classIndex, 0); @@ -206,14 +182,13 @@ TEST(InstanceSegFilterTests, ClassFilterReturnsOnlyMatchingClasses) { } TEST(InstanceSegFilterTests, EmptyFilterReturnsAllClasses) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); - auto allResults = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, - {}, true, kMethodName); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); + auto allResults = + model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName); EXPECT_FALSE(allResults.empty()); - auto noResults = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, - {50}, true, kMethodName); + auto noResults = + model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, {50}, true, kMethodName); EXPECT_TRUE(noResults.empty()); } @@ -221,12 +196,9 @@ TEST(InstanceSegFilterTests, EmptyFilterReturnsAllClasses) { // Mask resolution tests // ============================================================================ TEST(InstanceSegMaskTests, LowResMaskIsSmallerThanOriginal) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); - auto hiRes = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, {}, - true, kMethodName); - auto loRes = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, {}, - false, kMethodName); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); + auto hiRes = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName); + auto loRes = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, {}, false, kMethodName); if (!hiRes.empty() && !loRes.empty()) { EXPECT_LE(loRes[0].mask->size(), hiRes[0].mask->size()); @@ -237,15 +209,13 @@ TEST(InstanceSegMaskTests, LowResMaskIsSmallerThanOriginal) { // NMS tests // ============================================================================ TEST(InstanceSegNMSTests, NMSEnabledReturnsFewerOrEqualResults) { - BaseInstanceSegmentation modelWithNMS(kValidInstanceSegModelPath, {}, {}, - true, nullptr); - BaseInstanceSegmentation modelWithoutNMS(kValidInstanceSegModelPath, {}, {}, - false, nullptr); + BaseInstanceSegmentation modelWithNMS(kValidInstanceSegModelPath, {}, {}, true, nullptr); + BaseInstanceSegmentation modelWithoutNMS(kValidInstanceSegModelPath, {}, {}, false, nullptr); - auto nmsResults = modelWithNMS.generateFromString( - kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName); - auto noNmsResults = modelWithoutNMS.generateFromString( - kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName); + auto nmsResults = + modelWithNMS.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName); + auto noNmsResults = + modelWithoutNMS.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName); EXPECT_LE(nmsResults.size(), noNmsResults.size()); } @@ -254,41 +224,32 @@ TEST(InstanceSegNMSTests, NMSEnabledReturnsFewerOrEqualResults) { // generateFromPixels tests // ============================================================================ TEST(InstanceSegPixelTests, ValidPixelDataReturnsResults) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); constexpr int32_t width = 4, height = 4, channels = 3; std::vector pixelData(width * height * channels, 128); - JSTensorViewIn tensorView{pixelData.data(), - {height, width, channels}, - executorch::aten::ScalarType::Byte}; - auto results = model.generateFromPixels(tensorView, 0.3, 0.5, 100, {}, true, - kMethodName); + JSTensorViewIn tensorView{ + pixelData.data(), {height, width, channels}, executorch::aten::ScalarType::Byte}; + auto results = model.generateFromPixels(tensorView, 0.3, 0.5, 100, {}, true, kMethodName); EXPECT_GE(results.size(), 0u); } TEST(InstanceSegPixelTests, NegativeConfidenceThrows) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); constexpr int32_t width = 4, height = 4, channels = 3; std::vector pixelData(width * height * channels, 128); - JSTensorViewIn tensorView{pixelData.data(), - {height, width, channels}, - executorch::aten::ScalarType::Byte}; - EXPECT_THROW((void)model.generateFromPixels(tensorView, -0.1, 0.5, 100, {}, - true, kMethodName), + JSTensorViewIn tensorView{ + pixelData.data(), {height, width, channels}, executorch::aten::ScalarType::Byte}; + EXPECT_THROW((void)model.generateFromPixels(tensorView, -0.1, 0.5, 100, {}, true, kMethodName), RnExecutorchError); } TEST(InstanceSegPixelTests, ConfidenceAboveOneThrows) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); constexpr int32_t width = 4, height = 4, channels = 3; std::vector pixelData(width * height * channels, 128); - JSTensorViewIn tensorView{pixelData.data(), - {height, width, channels}, - executorch::aten::ScalarType::Byte}; - EXPECT_THROW((void)model.generateFromPixels(tensorView, 1.1, 0.5, 100, {}, - true, kMethodName), + JSTensorViewIn tensorView{ + pixelData.data(), {height, width, channels}, executorch::aten::ScalarType::Byte}; + EXPECT_THROW((void)model.generateFromPixels(tensorView, 1.1, 0.5, 100, {}, true, kMethodName), RnExecutorchError); } @@ -296,8 +257,7 @@ TEST(InstanceSegPixelTests, ConfidenceAboveOneThrows) { // Inherited method tests // ============================================================================ TEST(InstanceSegInheritedTests, GetInputShapeWorks) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); auto shape = model.getInputShape(kMethodName, 0); EXPECT_EQ(shape.size(), 4); EXPECT_EQ(shape[0], 1); @@ -305,15 +265,13 @@ TEST(InstanceSegInheritedTests, GetInputShapeWorks) { } TEST(InstanceSegInheritedTests, GetAllInputShapesWorks) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); auto shapes = model.getAllInputShapes(kMethodName); EXPECT_FALSE(shapes.empty()); } TEST(InstanceSegInheritedTests, GetMethodMetaWorks) { - BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, - nullptr); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true, nullptr); auto result = model.getMethodMeta(kMethodName); EXPECT_TRUE(result.ok()); } @@ -324,15 +282,13 @@ TEST(InstanceSegInheritedTests, GetMethodMetaWorks) { TEST(InstanceSegNormTests, ValidNormParamsDoesntThrow) { const std::vector mean = {0.485f, 0.456f, 0.406f}; const std::vector std = {0.229f, 0.224f, 0.225f}; - EXPECT_NO_THROW(BaseInstanceSegmentation(kValidInstanceSegModelPath, mean, - std, true, nullptr)); + EXPECT_NO_THROW(BaseInstanceSegmentation(kValidInstanceSegModelPath, mean, std, true, nullptr)); } TEST(InstanceSegNormTests, ValidNormParamsGenerateSucceeds) { const std::vector mean = {0.485f, 0.456f, 0.406f}; const std::vector std = {0.229f, 0.224f, 0.225f}; - BaseInstanceSegmentation model(kValidInstanceSegModelPath, mean, std, true, - nullptr); - EXPECT_NO_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.5, - 100, {}, true, kMethodName)); + BaseInstanceSegmentation model(kValidInstanceSegModelPath, mean, std, true, nullptr); + EXPECT_NO_THROW( + (void)model.generateFromString(kValidTestImagePath, 0.5, 0.5, 100, {}, true, kMethodName)); } diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/LLMTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/LLMTest.cpp index 4b34f4248e..5a143f565c 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/LLMTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/LLMTest.cpp @@ -27,17 +27,14 @@ std::shared_ptr createMockCallInvoker(); } // Helper to format prompt in ChatML format for SmolLM2 -std::string formatChatML(const std::string &systemPrompt, - const std::string &userMessage) { - return "<|im_start|>system\n" + systemPrompt + "<|im_end|>\n" + - "<|im_start|>user\n" + userMessage + "<|im_end|>\n" + - "<|im_start|>assistant\n"; +std::string formatChatML(const std::string &systemPrompt, const std::string &userMessage) { + return "<|im_start|>system\n" + systemPrompt + "<|im_end|>\n" + "<|im_start|>user\n" + + userMessage + "<|im_end|>\n" + "<|im_start|>assistant\n"; } // Helper to format a single-turn prompt in Gemma's chat template. std::string formatGemma(const std::string &userMessage) { - return "user\n" + userMessage + "\n" + - "model\n"; + return "user\n" + userMessage + "\n" + "model\n"; } // ============================================================================ @@ -78,14 +75,12 @@ class LLMTest : public ::testing::Test { }; TEST(LLMCtorTests, InvalidTokenizerPathThrows) { - EXPECT_THROW(LLM(kValidModelPath, "nonexistent_tokenizer.json", {}, - createMockCallInvoker()), + EXPECT_THROW(LLM(kValidModelPath, "nonexistent_tokenizer.json", {}, createMockCallInvoker()), RnExecutorchError); } TEST(LLMCtorTests, WrongCapabilitiesThrowsClearError) { - EXPECT_THROW(LLM(kValidModelPath, kValidTokenizerPath, {"vision"}, - createMockCallInvoker()), + EXPECT_THROW(LLM(kValidModelPath, kValidTokenizerPath, {"vision"}, createMockCallInvoker()), rnexecutorch::RnExecutorchError); } @@ -176,8 +171,7 @@ TEST_F(LLMTest, SettersThrowWhenUnloaded) { TEST_F(LLMTest, GenerateProducesValidOutput) { LLM model(kValidModelPath, kValidTokenizerPath, {}, mockInvoker_); model.setTemperature(0.0f); - std::string prompt = - formatChatML(kSystemPrompt, "Repeat exactly this: `naszponcilem testy`"); + std::string prompt = formatChatML(kSystemPrompt, "Repeat exactly this: `naszponcilem testy`"); std::string output = model.generate(prompt, nullptr); EXPECT_EQ(output, "`naszponcilem testy`<|im_end|>"); } @@ -185,8 +179,7 @@ TEST_F(LLMTest, GenerateProducesValidOutput) { TEST_F(LLMTest, GenerateUpdatesTokenCount) { LLM model(kValidModelPath, kValidTokenizerPath, {}, mockInvoker_); EXPECT_EQ(model.getGeneratedTokenCount(), 0); - std::string prompt = - formatChatML(kSystemPrompt, "Repeat exactly this: 'naszponcilem testy'"); + std::string prompt = formatChatML(kSystemPrompt, "Repeat exactly this: 'naszponcilem testy'"); model.generate(prompt, nullptr); EXPECT_GT(model.getGeneratedTokenCount(), 0); } @@ -228,11 +221,9 @@ TEST_F(LLMTest, PromptTokenCountNonZeroAfterGenerate) { TEST(VisionEncoderTest, LoadFailsWithClearErrorWhenMethodMissing) { // smolLm2_135M_8da4w.pte has no vision_encoder method auto module = std::make_unique<::executorch::extension::Module>( - "smolLm2_135M_8da4w.pte", - ::executorch::extension::Module::LoadMode::File); + "smolLm2_135M_8da4w.pte", ::executorch::extension::Module::LoadMode::File); - auto encoder = - std::make_unique(*module); + auto encoder = std::make_unique(*module); EXPECT_THROW(encoder->load(), rnexecutorch::RnExecutorchError); } @@ -240,11 +231,9 @@ TEST(VisionEncoderTest, LoadFailsWithClearErrorWhenMethodMissing) { TEST(AudioEncoderTest, LoadFailsWithClearErrorWhenMethodMissing) { // smolLm2_135M_8da4w.pte has no audio_encoder method auto module = std::make_unique<::executorch::extension::Module>( - "smolLm2_135M_8da4w.pte", - ::executorch::extension::Module::LoadMode::File); + "smolLm2_135M_8da4w.pte", ::executorch::extension::Module::LoadMode::File); - auto encoder = - std::make_unique(*module); + auto encoder = std::make_unique(*module); EXPECT_THROW(encoder->load(), rnexecutorch::RnExecutorchError); } @@ -255,8 +244,7 @@ TEST(AudioEncoderTest, LoadFailsWithClearErrorWhenMethodMissing) { constexpr auto kVlmModelPath = "lfm2_5_vl_quantized_xnnpack_v2.pte"; constexpr auto kVlmTokenizerPath = "lfm2_vl_tokenizer.json"; constexpr auto kVlmImageToken = ""; -constexpr auto kTestImagePath = - "file:///data/local/tmp/rnexecutorch_tests/test_image.jpg"; +constexpr auto kTestImagePath = "file:///data/local/tmp/rnexecutorch_tests/test_image.jpg"; TEST_F(LLMTest, TextModelIsNotMultimodal) { LLM model(kValidModelPath, kValidTokenizerPath, {}, mockInvoker_); @@ -267,10 +255,8 @@ TEST_F(LLMTest, GenerateMultimodalOnTextModelThrows) { LLM model(kValidModelPath, kValidTokenizerPath, {}, mockInvoker_); // A text-only runner reports is_multimodal() == false, so any multimodal // call must be rejected before the inputs are even inspected. - MultimodalInputs inputs{.images = - ImageInputs{.paths = {}, .token = ""}}; - EXPECT_THROW(model.generateMultimodal("hello", nullptr, std::move(inputs)), - RnExecutorchError); + MultimodalInputs inputs{.images = ImageInputs{.paths = {}, .token = ""}}; + EXPECT_THROW(model.generateMultimodal("hello", nullptr, std::move(inputs)), RnExecutorchError); } // Fixture that loads the VLM model once for all VLM tests @@ -278,9 +264,8 @@ class VLMTest : public ::testing::Test { protected: static void SetUpTestSuite() { invoker_ = createMockCallInvoker(); - model_ = - std::make_unique(kVlmModelPath, kVlmTokenizerPath, - std::vector{"vision"}, invoker_); + model_ = std::make_unique(kVlmModelPath, kVlmTokenizerPath, + std::vector{"vision"}, invoker_); } static void TearDownTestSuite() { @@ -296,27 +281,22 @@ std::shared_ptr VLMTest::invoker_; std::unique_ptr VLMTest::model_; TEST_F(VLMTest, GenerateMultimodalEmptyImageTokenThrows) { - MultimodalInputs inputs{ - .images = ImageInputs{.paths = {kTestImagePath}, .token = ""}}; - EXPECT_THROW(model_->generateMultimodal("hello", nullptr, std::move(inputs)), - RnExecutorchError); + MultimodalInputs inputs{.images = ImageInputs{.paths = {kTestImagePath}, .token = ""}}; + EXPECT_THROW(model_->generateMultimodal("hello", nullptr, std::move(inputs)), RnExecutorchError); } TEST_F(VLMTest, GenerateMultimodalMorePlaceholdersThanImagePaths) { std::string prompt = std::string(kVlmImageToken) + " and " + kVlmImageToken; - MultimodalInputs inputs{.images = ImageInputs{.paths = {kTestImagePath}, - .token = kVlmImageToken}}; - EXPECT_THROW(model_->generateMultimodal(prompt, nullptr, std::move(inputs)), - RnExecutorchError); + MultimodalInputs inputs{.images = + ImageInputs{.paths = {kTestImagePath}, .token = kVlmImageToken}}; + EXPECT_THROW(model_->generateMultimodal(prompt, nullptr, std::move(inputs)), RnExecutorchError); } TEST_F(VLMTest, GenerateMultimodalMoreImagePathsThanPlaceholders) { std::string prompt = std::string(kVlmImageToken) + " describe"; MultimodalInputs inputs{ - .images = ImageInputs{.paths = {kTestImagePath, kTestImagePath}, - .token = kVlmImageToken}}; - EXPECT_THROW(model_->generateMultimodal(prompt, nullptr, std::move(inputs)), - RnExecutorchError); + .images = ImageInputs{.paths = {kTestImagePath, kTestImagePath}, .token = kVlmImageToken}}; + EXPECT_THROW(model_->generateMultimodal(prompt, nullptr, std::move(inputs)), RnExecutorchError); } // ============================================================================ @@ -333,8 +313,7 @@ class GemmaAudioTest : public ::testing::Test { static void SetUpTestSuite() { invoker_ = createMockCallInvoker(); model_ = std::make_unique(kGemmaModelPath, kGemmaTokenizerPath, - std::vector{"vision", "audio"}, - invoker_); + std::vector{"vision", "audio"}, invoker_); } static void TearDownTestSuite() { @@ -358,58 +337,47 @@ std::shared_ptr GemmaAudioTest::invoker_; std::unique_ptr GemmaAudioTest::model_; TEST_F(GemmaAudioTest, GenerateMultimodalNoInputsThrows) { - EXPECT_THROW(model_->generateMultimodal("hello", nullptr, {}), - RnExecutorchError); + EXPECT_THROW(model_->generateMultimodal("hello", nullptr, {}), RnExecutorchError); } TEST_F(GemmaAudioTest, GenerateMultimodalEmptyAudioTokenThrows) { - MultimodalInputs inputs{ - .audios = AudioInputs{.waveforms = {loadAudio()}, .token = ""}}; - EXPECT_THROW(model_->generateMultimodal("hello", nullptr, std::move(inputs)), - RnExecutorchError); + MultimodalInputs inputs{.audios = AudioInputs{.waveforms = {loadAudio()}, .token = ""}}; + EXPECT_THROW(model_->generateMultimodal("hello", nullptr, std::move(inputs)), RnExecutorchError); } TEST_F(GemmaAudioTest, GenerateMultimodalMorePlaceholdersThanWaveformsThrows) { - std::string prompt = - std::string(kGemmaAudioToken) + " and " + kGemmaAudioToken; - MultimodalInputs inputs{.audios = AudioInputs{.waveforms = {loadAudio()}, - .token = kGemmaAudioToken}}; - EXPECT_THROW(model_->generateMultimodal(prompt, nullptr, std::move(inputs)), - RnExecutorchError); + std::string prompt = std::string(kGemmaAudioToken) + " and " + kGemmaAudioToken; + MultimodalInputs inputs{.audios = + AudioInputs{.waveforms = {loadAudio()}, .token = kGemmaAudioToken}}; + EXPECT_THROW(model_->generateMultimodal(prompt, nullptr, std::move(inputs)), RnExecutorchError); } TEST_F(GemmaAudioTest, GenerateMultimodalMoreWaveformsThanPlaceholdersThrows) { std::string prompt = std::string(kGemmaAudioToken) + " describe"; MultimodalInputs inputs{ - .audios = AudioInputs{.waveforms = {loadAudio(), loadAudio()}, - .token = kGemmaAudioToken}}; - EXPECT_THROW(model_->generateMultimodal(prompt, nullptr, std::move(inputs)), - RnExecutorchError); + .audios = AudioInputs{.waveforms = {loadAudio(), loadAudio()}, .token = kGemmaAudioToken}}; + EXPECT_THROW(model_->generateMultimodal(prompt, nullptr, std::move(inputs)), RnExecutorchError); } TEST_F(GemmaAudioTest, GenerateMultimodalAudioProducesOutput) { std::vector wav = loadAudio(); - ASSERT_FALSE(wav.empty()) - << "test_audio_float.raw missing on device - check run_tests.sh assets"; + ASSERT_FALSE(wav.empty()) << "test_audio_float.raw missing on device - check run_tests.sh assets"; - std::string prompt = - formatGemma(std::string(kGemmaAudioToken) + " Transcribe the audio."); - MultimodalInputs inputs{.audios = AudioInputs{.waveforms = {std::move(wav)}, - .token = kGemmaAudioToken}}; - std::string output = - model_->generateMultimodal(prompt, nullptr, std::move(inputs)); + std::string prompt = formatGemma(std::string(kGemmaAudioToken) + " Transcribe the audio."); + MultimodalInputs inputs{ + .audios = AudioInputs{.waveforms = {std::move(wav)}, .token = kGemmaAudioToken}}; + std::string output = model_->generateMultimodal(prompt, nullptr, std::move(inputs)); EXPECT_FALSE(output.empty()); EXPECT_GT(model_->getGeneratedTokenCount(), 0); } TEST_F(GemmaAudioTest, GenerateMultimodalInterleavedTextAndAudio) { - std::string prompt = formatGemma("Listen: " + std::string(kGemmaAudioToken) + - " then summarise it."); - MultimodalInputs inputs{.audios = AudioInputs{.waveforms = {loadAudio()}, - .token = kGemmaAudioToken}}; - std::string output = - model_->generateMultimodal(prompt, nullptr, std::move(inputs)); + std::string prompt = + formatGemma("Listen: " + std::string(kGemmaAudioToken) + " then summarise it."); + MultimodalInputs inputs{.audios = + AudioInputs{.waveforms = {loadAudio()}, .token = kGemmaAudioToken}}; + std::string output = model_->generateMultimodal(prompt, nullptr, std::move(inputs)); EXPECT_FALSE(output.empty()); } diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/OCRTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/OCRTest.cpp index a97e4c2121..55e87e3214 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/OCRTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/OCRTest.cpp @@ -20,10 +20,9 @@ constexpr auto kValidTestImagePath = "file:///data/local/tmp/rnexecutorch_tests/we_are_software_mansion.jpg"; // English alphabet symbols (must match alphabets.english from symbols.ts) -const std::string ENGLISH_SYMBOLS = - "0123456789!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~ " - "\xE2\x82\xAC" // Euro sign (€) - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; +const std::string ENGLISH_SYMBOLS = "0123456789!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~ " + "\xE2\x82\xAC" // Euro sign (€) + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; // ============================================================================ // Common tests via typed test suite @@ -55,48 +54,39 @@ INSTANTIATE_TYPED_TEST_SUITE_P(OCR, CommonModelTest, OCRTypes); // Model-specific tests // ============================================================================ TEST(OCRCtorTests, InvalidRecognizerPathThrows) { - EXPECT_THROW(OCR(kValidDetectorPath, "nonexistent.pte", ENGLISH_SYMBOLS, - createMockCallInvoker()), + EXPECT_THROW(OCR(kValidDetectorPath, "nonexistent.pte", ENGLISH_SYMBOLS, createMockCallInvoker()), RnExecutorchError); } TEST(OCRCtorTests, EmptySymbolsThrows) { - EXPECT_THROW(OCR(kValidDetectorPath, kValidRecognizerPath, "", - createMockCallInvoker()), + EXPECT_THROW(OCR(kValidDetectorPath, kValidRecognizerPath, "", createMockCallInvoker()), RnExecutorchError); } TEST(OCRGenerateTests, InvalidImagePathThrows) { - OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, - createMockCallInvoker()); - EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"), - RnExecutorchError); + OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, createMockCallInvoker()); + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"), RnExecutorchError); } TEST(OCRGenerateTests, EmptyImagePathThrows) { - OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, - createMockCallInvoker()); + OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, createMockCallInvoker()); EXPECT_THROW((void)model.generateFromString(""), RnExecutorchError); } TEST(OCRGenerateTests, MalformedURIThrows) { - OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, - createMockCallInvoker()); - EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"), - RnExecutorchError); + OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, createMockCallInvoker()); + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"), RnExecutorchError); } TEST(OCRGenerateTests, ValidImageReturnsResults) { - OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, - createMockCallInvoker()); + OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, createMockCallInvoker()); auto results = model.generateFromString(kValidTestImagePath); // May or may not have detections depending on image content EXPECT_GE(results.size(), 0u); } TEST(OCRGenerateTests, DetectionsHaveValidBoundingBoxes) { - OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, - createMockCallInvoker()); + OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, createMockCallInvoker()); auto results = model.generateFromString(kValidTestImagePath); for (const auto &detection : results) { @@ -108,8 +98,7 @@ TEST(OCRGenerateTests, DetectionsHaveValidBoundingBoxes) { } TEST(OCRGenerateTests, DetectionsHaveValidScores) { - OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, - createMockCallInvoker()); + OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, createMockCallInvoker()); auto results = model.generateFromString(kValidTestImagePath); for (const auto &detection : results) { @@ -119,8 +108,7 @@ TEST(OCRGenerateTests, DetectionsHaveValidScores) { } TEST(OCRGenerateTests, DetectionsHaveNonEmptyText) { - OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, - createMockCallInvoker()); + OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, createMockCallInvoker()); auto results = model.generateFromString(kValidTestImagePath); for (const auto &detection : results) { EXPECT_FALSE(detection.text.empty()); @@ -131,11 +119,9 @@ TEST(OCRGenerateTests, DetectionsHaveNonEmptyText) { // generateFromPixels smoke test // ============================================================================ TEST(OCRPixelTests, ValidPixelsReturnsResults) { - OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, - createMockCallInvoker()); + OCR model(kValidDetectorPath, kValidRecognizerPath, ENGLISH_SYMBOLS, createMockCallInvoker()); std::vector buf(64 * 64 * 3, 128); - JSTensorViewIn view{ - buf.data(), {64, 64, 3}, executorch::aten::ScalarType::Byte}; + JSTensorViewIn view{buf.data(), {64, 64, 3}, executorch::aten::ScalarType::Byte}; auto results = model.generateFromPixels(view); EXPECT_GE(results.size(), 0u); } diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp index 5c5bb6e736..7f4e94ac32 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp @@ -11,10 +11,8 @@ using namespace rnexecutorch; using namespace rnexecutorch::models::object_detection; using namespace model_tests; -constexpr auto kValidObjectDetectionModelPath = - "ssdlite320-mobilenetv3-large.pte"; -constexpr auto kValidTestImagePath = - "file:///data/local/tmp/rnexecutorch_tests/test_image.jpg"; +constexpr auto kValidObjectDetectionModelPath = "ssdlite320-mobilenetv3-large.pte"; +constexpr auto kValidTestImagePath = "file:///data/local/tmp/rnexecutorch_tests/test_image.jpg"; // clang-format off const std::vector kCocoLabels = { @@ -41,80 +39,61 @@ template <> struct ModelTraits { using ModelType = ObjectDetection; static ModelType createValid() { - return ModelType(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); + return ModelType(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); } - static ModelType createInvalid() { - return ModelType("nonexistent.pte", {}, {}, {}, nullptr); - } + static ModelType createInvalid() { return ModelType("nonexistent.pte", {}, {}, {}, nullptr); } static void callGenerate(ModelType &model) { - (void)model.generateFromString(kValidTestImagePath, 0.5, 0.55, {}, - "forward"); + (void)model.generateFromString(kValidTestImagePath, 0.5, 0.55, {}, "forward"); } }; } // namespace model_tests using ObjectDetectionTypes = ::testing::Types; -INSTANTIATE_TYPED_TEST_SUITE_P(ObjectDetection, CommonModelTest, - ObjectDetectionTypes); -INSTANTIATE_TYPED_TEST_SUITE_P(ObjectDetection, VisionModelTest, - ObjectDetectionTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(ObjectDetection, CommonModelTest, ObjectDetectionTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(ObjectDetection, VisionModelTest, ObjectDetectionTypes); // ============================================================================ // Model-specific tests // ============================================================================ TEST(ObjectDetectionGenerateTests, InvalidImagePathThrows) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); - EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg", 0.5, - 0.55, {}, "forward"), + ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg", 0.5, 0.55, {}, "forward"), RnExecutorchError); } TEST(ObjectDetectionGenerateTests, EmptyImagePathThrows) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); - EXPECT_THROW((void)model.generateFromString("", 0.5, 0.55, {}, "forward"), - RnExecutorchError); + ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); + EXPECT_THROW((void)model.generateFromString("", 0.5, 0.55, {}, "forward"), RnExecutorchError); } TEST(ObjectDetectionGenerateTests, MalformedURIThrows) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); - EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad", 0.5, - 0.55, {}, "forward"), + ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad", 0.5, 0.55, {}, "forward"), RnExecutorchError); } TEST(ObjectDetectionGenerateTests, NegativeThresholdThrows) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); - EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, -0.1, 0.55, - {}, "forward"), + ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); + EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, -0.1, 0.55, {}, "forward"), RnExecutorchError); } TEST(ObjectDetectionGenerateTests, ThresholdAboveOneThrows) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); - EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 1.1, 0.55, - {}, "forward"), + ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); + EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 1.1, 0.55, {}, "forward"), RnExecutorchError); } TEST(ObjectDetectionGenerateTests, ValidImageReturnsResults) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); - auto results = - model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, "forward"); + ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); + auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, "forward"); EXPECT_GE(results.size(), 0u); } TEST(ObjectDetectionGenerateTests, HighThresholdReturnsFewerResults) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); + ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); auto lowThresholdResults = model.generateFromString(kValidTestImagePath, 0.1, 0.55, {}, "forward"); auto highThresholdResults = @@ -123,10 +102,8 @@ TEST(ObjectDetectionGenerateTests, HighThresholdReturnsFewerResults) { } TEST(ObjectDetectionGenerateTests, DetectionsHaveValidBoundingBoxes) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); - auto results = - model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, "forward"); + ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); + auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, "forward"); for (const auto &detection : results) { EXPECT_LE(detection.bbox.p1.x, detection.bbox.p2.x); @@ -137,10 +114,8 @@ TEST(ObjectDetectionGenerateTests, DetectionsHaveValidBoundingBoxes) { } TEST(ObjectDetectionGenerateTests, DetectionsHaveValidScores) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); - auto results = - model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, "forward"); + ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); + auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, "forward"); for (const auto &detection : results) { EXPECT_GE(detection.score, 0.0f); @@ -149,16 +124,13 @@ TEST(ObjectDetectionGenerateTests, DetectionsHaveValidScores) { } TEST(ObjectDetectionGenerateTests, DetectionsHaveValidLabels) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); - auto results = - model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, "forward"); + ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); + auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, "forward"); for (const auto &detection : results) { const auto &label = detection.label; EXPECT_FALSE(label.empty()); - EXPECT_NE(std::find(kCocoLabels.begin(), kCocoLabels.end(), label), - kCocoLabels.end()); + EXPECT_NE(std::find(kCocoLabels.begin(), kCocoLabels.end(), label), kCocoLabels.end()); } } @@ -166,46 +138,37 @@ TEST(ObjectDetectionGenerateTests, DetectionsHaveValidLabels) { // generateFromPixels tests // ============================================================================ TEST(ObjectDetectionPixelTests, ValidPixelDataReturnsResults) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); + ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); constexpr int32_t width = 4, height = 4, channels = 3; std::vector pixelData(width * height * channels, 128); - JSTensorViewIn tensorView{pixelData.data(), - {height, width, channels}, - executorch::aten::ScalarType::Byte}; + JSTensorViewIn tensorView{ + pixelData.data(), {height, width, channels}, executorch::aten::ScalarType::Byte}; auto results = model.generateFromPixels(tensorView, 0.3, 0.55, {}, "forward"); EXPECT_GE(results.size(), 0u); } TEST(ObjectDetectionPixelTests, NegativeThresholdThrows) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); + ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); constexpr int32_t width = 4, height = 4, channels = 3; std::vector pixelData(width * height * channels, 128); - JSTensorViewIn tensorView{pixelData.data(), - {height, width, channels}, - executorch::aten::ScalarType::Byte}; - EXPECT_THROW( - (void)model.generateFromPixels(tensorView, -0.1, 0.55, {}, "forward"), - RnExecutorchError); + JSTensorViewIn tensorView{ + pixelData.data(), {height, width, channels}, executorch::aten::ScalarType::Byte}; + EXPECT_THROW((void)model.generateFromPixels(tensorView, -0.1, 0.55, {}, "forward"), + RnExecutorchError); } TEST(ObjectDetectionPixelTests, ThresholdAboveOneThrows) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); + ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); constexpr int32_t width = 4, height = 4, channels = 3; std::vector pixelData(width * height * channels, 128); - JSTensorViewIn tensorView{pixelData.data(), - {height, width, channels}, - executorch::aten::ScalarType::Byte}; - EXPECT_THROW( - (void)model.generateFromPixels(tensorView, 1.1, 0.55, {}, "forward"), - RnExecutorchError); + JSTensorViewIn tensorView{ + pixelData.data(), {height, width, channels}, executorch::aten::ScalarType::Byte}; + EXPECT_THROW((void)model.generateFromPixels(tensorView, 1.1, 0.55, {}, "forward"), + RnExecutorchError); } TEST(ObjectDetectionInheritedTests, GetInputShapeWorks) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); + ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); auto shape = model.getInputShape("forward", 0); // v0.9.0 ssdlite is exported without a leading batch dim, so the signature // is (C, H, W) rather than the (1, C, H, W) used by sibling vision models. @@ -215,15 +178,13 @@ TEST(ObjectDetectionInheritedTests, GetInputShapeWorks) { } TEST(ObjectDetectionInheritedTests, GetAllInputShapesWorks) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); + ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); auto shapes = model.getAllInputShapes("forward"); EXPECT_FALSE(shapes.empty()); } TEST(ObjectDetectionInheritedTests, GetMethodMetaWorks) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); + ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); auto result = model.getMethodMeta("forward"); EXPECT_TRUE(result.ok()); } @@ -234,73 +195,58 @@ TEST(ObjectDetectionInheritedTests, GetMethodMetaWorks) { TEST(ObjectDetectionNormTests, ValidNormParamsDoesntThrow) { const std::vector mean = {0.485f, 0.456f, 0.406f}; const std::vector std = {0.229f, 0.224f, 0.225f}; - EXPECT_NO_THROW( - ObjectDetection(kValidObjectDetectionModelPath, mean, std, {}, nullptr)); + EXPECT_NO_THROW(ObjectDetection(kValidObjectDetectionModelPath, mean, std, {}, nullptr)); } TEST(ObjectDetectionNormTests, InvalidNormMeanSizeDoesntThrow) { - EXPECT_NO_THROW(ObjectDetection(kValidObjectDetectionModelPath, {0.5f}, - {0.229f, 0.224f, 0.225f}, {}, nullptr)); + EXPECT_NO_THROW(ObjectDetection(kValidObjectDetectionModelPath, {0.5f}, {0.229f, 0.224f, 0.225f}, + {}, nullptr)); } TEST(ObjectDetectionNormTests, InvalidNormStdSizeDoesntThrow) { - EXPECT_NO_THROW(ObjectDetection(kValidObjectDetectionModelPath, - {0.485f, 0.456f, 0.406f}, {0.5f}, {}, - nullptr)); + EXPECT_NO_THROW(ObjectDetection(kValidObjectDetectionModelPath, {0.485f, 0.456f, 0.406f}, {0.5f}, + {}, nullptr)); } TEST(ObjectDetectionNormTests, ValidNormParamsGenerateSucceeds) { const std::vector mean = {0.485f, 0.456f, 0.406f}; const std::vector std = {0.229f, 0.224f, 0.225f}; - ObjectDetection model(kValidObjectDetectionModelPath, mean, std, kCocoLabels, - nullptr); - EXPECT_NO_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.55, - {}, "forward")); + ObjectDetection model(kValidObjectDetectionModelPath, mean, std, kCocoLabels, nullptr); + EXPECT_NO_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.55, {}, "forward")); } // ============================================================================ // Method name tests // ============================================================================ TEST(ObjectDetectionMethodTests, InvalidMethodNameThrows) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); - EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.55, - {}, "forward_999"), + ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); + EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.55, {}, "forward_999"), RnExecutorchError); } TEST(ObjectDetectionMethodTests, EmptyMethodNameThrows) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); - EXPECT_THROW( - (void)model.generateFromString(kValidTestImagePath, 0.5, 0.55, {}, ""), - RnExecutorchError); + ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); + EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.55, {}, ""), + RnExecutorchError); } // ============================================================================ // Class indices filtering tests // ============================================================================ -TEST(ObjectDetectionClassFilterTests, - FilteredResultsOnlyContainRequestedClasses) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); +TEST(ObjectDetectionClassFilterTests, FilteredResultsOnlyContainRequestedClasses) { + ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); // Only request "person" class (index 0 in COCO) - auto results = - model.generateFromString(kValidTestImagePath, 0.3, 0.55, {0}, "forward"); + auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.55, {0}, "forward"); for (const auto &det : results) { EXPECT_EQ(det.label, "person"); } } -TEST(ObjectDetectionClassFilterTests, - EmptyClassIndicesReturnsMoreOrEqualResults) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); - auto allClasses = - model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, "forward"); +TEST(ObjectDetectionClassFilterTests, EmptyClassIndicesReturnsMoreOrEqualResults) { + ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); + auto allClasses = model.generateFromString(kValidTestImagePath, 0.3, 0.55, {}, "forward"); // person (0) only - auto filtered = - model.generateFromString(kValidTestImagePath, 0.3, 0.55, {0}, "forward"); + auto filtered = model.generateFromString(kValidTestImagePath, 0.3, 0.55, {0}, "forward"); EXPECT_GE(allClasses.size(), filtered.size()); } @@ -308,13 +254,10 @@ TEST(ObjectDetectionClassFilterTests, // IoU threshold tests // ============================================================================ TEST(ObjectDetectionIouTests, HigherIouThresholdReturnsSameOrMoreResults) { - ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, - nullptr); + ObjectDetection model(kValidObjectDetectionModelPath, {}, {}, kCocoLabels, nullptr); // High IoU threshold = less aggressive NMS = more boxes survive - auto highIou = - model.generateFromString(kValidTestImagePath, 0.3, 0.9, {}, "forward"); + auto highIou = model.generateFromString(kValidTestImagePath, 0.3, 0.9, {}, "forward"); // Low IoU threshold = more aggressive NMS = fewer boxes survive - auto lowIou = - model.generateFromString(kValidTestImagePath, 0.3, 0.1, {}, "forward"); + auto lowIou = model.generateFromString(kValidTestImagePath, 0.3, 0.1, {}, "forward"); EXPECT_GE(highIou.size(), lowIou.size()); } diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/PoseEstimationTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/PoseEstimationTest.cpp index 81f7100512..15c1918b8c 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/PoseEstimationTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/PoseEstimationTest.cpp @@ -22,13 +22,9 @@ namespace model_tests { template <> struct ModelTraits { using ModelType = PoseEstimation; - static ModelType createValid() { - return ModelType(kValidPoseModelPath, {}, {}, nullptr); - } + static ModelType createValid() { return ModelType(kValidPoseModelPath, {}, {}, nullptr); } - static ModelType createInvalid() { - return ModelType("nonexistent.pte", {}, {}, nullptr); - } + static ModelType createInvalid() { return ModelType("nonexistent.pte", {}, {}, nullptr); } static void callGenerate(ModelType &model) { (void)model.generateFromString(kValidTestImagePath, 0.5, 0.5, kMethodName); @@ -37,31 +33,26 @@ template <> struct ModelTraits { } // namespace model_tests using PoseEstimationTypes = ::testing::Types; -INSTANTIATE_TYPED_TEST_SUITE_P(PoseEstimation, CommonModelTest, - PoseEstimationTypes); -INSTANTIATE_TYPED_TEST_SUITE_P(PoseEstimation, VisionModelTest, - PoseEstimationTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(PoseEstimation, CommonModelTest, PoseEstimationTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(PoseEstimation, VisionModelTest, PoseEstimationTypes); // ============================================================================ // generateFromString — input path validity // ============================================================================ TEST(PoseEstimationGenerateTests, InvalidImagePathThrows) { PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); - EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg", 0.5, 0.5, - kMethodName), + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg", 0.5, 0.5, kMethodName), RnExecutorchError); } TEST(PoseEstimationGenerateTests, EmptyImagePathThrows) { PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); - EXPECT_THROW((void)model.generateFromString("", 0.5, 0.5, kMethodName), - RnExecutorchError); + EXPECT_THROW((void)model.generateFromString("", 0.5, 0.5, kMethodName), RnExecutorchError); } TEST(PoseEstimationGenerateTests, MalformedURIThrows) { PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); - EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad", 0.5, 0.5, - kMethodName), + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad", 0.5, 0.5, kMethodName), RnExecutorchError); } @@ -70,29 +61,25 @@ TEST(PoseEstimationGenerateTests, MalformedURIThrows) { // ============================================================================ TEST(PoseEstimationGenerateTests, NegativeDetectionThresholdThrows) { PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); - EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, -0.1, 0.5, - kMethodName), + EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, -0.1, 0.5, kMethodName), RnExecutorchError); } TEST(PoseEstimationGenerateTests, DetectionThresholdAboveOneThrows) { PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); - EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 1.1, 0.5, - kMethodName), + EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 1.1, 0.5, kMethodName), RnExecutorchError); } TEST(PoseEstimationGenerateTests, NegativeKeypointThresholdThrows) { PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); - EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, -0.1, - kMethodName), + EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, -0.1, kMethodName), RnExecutorchError); } TEST(PoseEstimationGenerateTests, KeypointThresholdAboveOneThrows) { PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); - EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 1.1, - kMethodName), + EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 1.1, kMethodName), RnExecutorchError); } @@ -101,24 +88,20 @@ TEST(PoseEstimationGenerateTests, KeypointThresholdAboveOneThrows) { // ============================================================================ TEST(PoseEstimationGenerateTests, ValidImageReturnsResults) { PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); - auto results = - model.generateFromString(kValidTestImagePath, 0.3, 0.5, kMethodName); + auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.5, kMethodName); EXPECT_GE(results.size(), 0u); } TEST(PoseEstimationGenerateTests, HighThresholdReturnsFewerResults) { PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); - auto lowThresholdResults = - model.generateFromString(kValidTestImagePath, 0.1, 0.5, kMethodName); - auto highThresholdResults = - model.generateFromString(kValidTestImagePath, 0.95, 0.5, kMethodName); + auto lowThresholdResults = model.generateFromString(kValidTestImagePath, 0.1, 0.5, kMethodName); + auto highThresholdResults = model.generateFromString(kValidTestImagePath, 0.95, 0.5, kMethodName); EXPECT_GE(lowThresholdResults.size(), highThresholdResults.size()); } TEST(PoseEstimationGenerateTests, AllDetectionsHaveSameKeypointCount) { PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); - auto results = - model.generateFromString(kValidTestImagePath, 0.1, 0.5, kMethodName); + auto results = model.generateFromString(kValidTestImagePath, 0.1, 0.5, kMethodName); if (results.size() < 2) { GTEST_SKIP() << "Need at least 2 detections to compare keypoint counts"; } @@ -131,8 +114,7 @@ TEST(PoseEstimationGenerateTests, AllDetectionsHaveSameKeypointCount) { TEST(PoseEstimationGenerateTests, KeypointsHaveValidStructure) { PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); - auto results = - model.generateFromString(kValidTestImagePath, 0.3, 0.5, kMethodName); + auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.5, kMethodName); // Each detection must contain a non-zero number of keypoints, and each // keypoint must be aggregate-initializable as { x, y } floats (compile-time). for (const auto &person : results) { @@ -154,9 +136,8 @@ TEST(PoseEstimationPixelTests, ValidPixelDataReturnsResults) { PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); constexpr int32_t width = 4, height = 4, channels = 3; std::vector pixelData(width * height * channels, 128); - JSTensorViewIn tensorView{pixelData.data(), - {height, width, channels}, - executorch::aten::ScalarType::Byte}; + JSTensorViewIn tensorView{ + pixelData.data(), {height, width, channels}, executorch::aten::ScalarType::Byte}; auto results = model.generateFromPixels(tensorView, 0.3, 0.5, kMethodName); EXPECT_GE(results.size(), 0u); } @@ -165,24 +146,20 @@ TEST(PoseEstimationPixelTests, NegativeThresholdThrows) { PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); constexpr int32_t width = 4, height = 4, channels = 3; std::vector pixelData(width * height * channels, 128); - JSTensorViewIn tensorView{pixelData.data(), - {height, width, channels}, - executorch::aten::ScalarType::Byte}; - EXPECT_THROW( - (void)model.generateFromPixels(tensorView, -0.1, 0.5, kMethodName), - RnExecutorchError); + JSTensorViewIn tensorView{ + pixelData.data(), {height, width, channels}, executorch::aten::ScalarType::Byte}; + EXPECT_THROW((void)model.generateFromPixels(tensorView, -0.1, 0.5, kMethodName), + RnExecutorchError); } TEST(PoseEstimationPixelTests, ThresholdAboveOneThrows) { PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); constexpr int32_t width = 4, height = 4, channels = 3; std::vector pixelData(width * height * channels, 128); - JSTensorViewIn tensorView{pixelData.data(), - {height, width, channels}, - executorch::aten::ScalarType::Byte}; - EXPECT_THROW( - (void)model.generateFromPixels(tensorView, 1.1, 0.5, kMethodName), - RnExecutorchError); + JSTensorViewIn tensorView{ + pixelData.data(), {height, width, channels}, executorch::aten::ScalarType::Byte}; + EXPECT_THROW((void)model.generateFromPixels(tensorView, 1.1, 0.5, kMethodName), + RnExecutorchError); } // ============================================================================ @@ -190,16 +167,14 @@ TEST(PoseEstimationPixelTests, ThresholdAboveOneThrows) { // ============================================================================ TEST(PoseEstimationMethodTests, InvalidMethodNameThrows) { PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); - EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.5, - "forward_999"), + EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.5, "forward_999"), RnExecutorchError); } TEST(PoseEstimationMethodTests, EmptyMethodNameThrows) { PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr); - EXPECT_THROW( - (void)model.generateFromString(kValidTestImagePath, 0.5, 0.5, ""), - RnExecutorchError); + EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.5, ""), + RnExecutorchError); } // ============================================================================ @@ -212,21 +187,18 @@ TEST(PoseEstimationNormTests, ValidNormParamsDoesntThrow) { } TEST(PoseEstimationNormTests, InvalidNormMeanSizeDoesntThrow) { - EXPECT_NO_THROW(PoseEstimation(kValidPoseModelPath, {0.5f}, - {0.229f, 0.224f, 0.225f}, nullptr)); + EXPECT_NO_THROW(PoseEstimation(kValidPoseModelPath, {0.5f}, {0.229f, 0.224f, 0.225f}, nullptr)); } TEST(PoseEstimationNormTests, InvalidNormStdSizeDoesntThrow) { - EXPECT_NO_THROW(PoseEstimation(kValidPoseModelPath, {0.485f, 0.456f, 0.406f}, - {0.5f}, nullptr)); + EXPECT_NO_THROW(PoseEstimation(kValidPoseModelPath, {0.485f, 0.456f, 0.406f}, {0.5f}, nullptr)); } TEST(PoseEstimationNormTests, ValidNormParamsGenerateSucceeds) { const std::vector mean = {0.485f, 0.456f, 0.406f}; const std::vector std = {0.229f, 0.224f, 0.225f}; PoseEstimation model(kValidPoseModelPath, mean, std, nullptr); - EXPECT_NO_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.5, - kMethodName)); + EXPECT_NO_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.5, kMethodName)); } // ============================================================================ diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/SemanticSegmentationTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/SemanticSegmentationTest.cpp index 09c5d42f67..142c21d1cd 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/SemanticSegmentationTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/SemanticSegmentationTest.cpp @@ -14,37 +14,31 @@ using executorch::extension::make_tensor_ptr; using executorch::extension::TensorPtr; using executorch::runtime::EValue; -constexpr auto kValidSemanticSegmentationModelPath = - "deeplabV3_xnnpack_fp32.pte"; -constexpr auto kValidTestImagePath = - "file:///data/local/tmp/rnexecutorch_tests/test_image.jpg"; +constexpr auto kValidSemanticSegmentationModelPath = "deeplabV3_xnnpack_fp32.pte"; +constexpr auto kValidTestImagePath = "file:///data/local/tmp/rnexecutorch_tests/test_image.jpg"; // DeepLab V3 class labels (Pascal VOC) static const std::vector kDeeplabV3Labels = { - "BACKGROUND", "AEROPLANE", "BICYCLE", "BIRD", "BOAT", - "BOTTLE", "BUS", "CAR", "CAT", "CHAIR", - "COW", "DININGTABLE", "DOG", "HORSE", "MOTORBIKE", - "PERSON", "POTTEDPLANT", "SHEEP", "SOFA", "TRAIN", - "TVMONITOR"}; + "BACKGROUND", "AEROPLANE", "BICYCLE", "BIRD", "BOAT", "BOTTLE", "BUS", + "CAR", "CAT", "CHAIR", "COW", "DININGTABLE", "DOG", "HORSE", + "MOTORBIKE", "PERSON", "POTTEDPLANT", "SHEEP", "SOFA", "TRAIN", "TVMONITOR"}; // ImageNet normalization constants static const std::vector kImageNetMean = {0.485f, 0.456f, 0.406f}; static const std::vector kImageNetStd = {0.229f, 0.224f, 0.225f}; -static JSTensorViewIn makeRgbView(std::vector &buf, int32_t h, - int32_t w) { +static JSTensorViewIn makeRgbView(std::vector &buf, int32_t h, int32_t w) { buf.assign(static_cast(h * w * 3), 128); - return JSTensorViewIn{ - buf.data(), {h, w, 3}, executorch::aten::ScalarType::Byte}; + return JSTensorViewIn{buf.data(), {h, w, 3}, executorch::aten::ScalarType::Byte}; } // Test fixture for tests that need dummy input data class SemanticSegmentationForwardTest : public ::testing::Test { protected: void SetUp() override { - model = std::make_unique( - kValidSemanticSegmentationModelPath, kImageNetMean, kImageNetStd, - kDeeplabV3Labels, nullptr); + model = std::make_unique(kValidSemanticSegmentationModelPath, + kImageNetMean, kImageNetStd, + kDeeplabV3Labels, nullptr); auto shapes = model->getAllInputShapes("forward"); ASSERT_FALSE(shapes.empty()); shape = shapes[0]; @@ -56,8 +50,7 @@ class SemanticSegmentationForwardTest : public ::testing::Test { dummyData = std::vector(numElements, 0.5f); sizes = std::vector(shape.begin(), shape.end()); - inputTensor = - make_tensor_ptr(sizes, dummyData.data(), exec_aten::ScalarType::Float); + inputTensor = make_tensor_ptr(sizes, dummyData.data(), exec_aten::ScalarType::Float); } std::unique_ptr model; @@ -68,16 +61,14 @@ class SemanticSegmentationForwardTest : public ::testing::Test { }; TEST(SemanticSegmentationCtorTests, InvalidPathThrows) { - EXPECT_THROW(BaseSemanticSegmentation("this_file_does_not_exist.pte", - kImageNetMean, kImageNetStd, + EXPECT_THROW(BaseSemanticSegmentation("this_file_does_not_exist.pte", kImageNetMean, kImageNetStd, kDeeplabV3Labels, nullptr), RnExecutorchError); } TEST(SemanticSegmentationCtorTests, ValidPathDoesntThrow) { - EXPECT_NO_THROW(BaseSemanticSegmentation(kValidSemanticSegmentationModelPath, - kImageNetMean, kImageNetStd, - kDeeplabV3Labels, nullptr)); + EXPECT_NO_THROW(BaseSemanticSegmentation(kValidSemanticSegmentationModelPath, kImageNetMean, + kImageNetStd, kDeeplabV3Labels, nullptr)); } TEST_F(SemanticSegmentationForwardTest, ForwardWithValidTensorSucceeds) { @@ -124,52 +115,44 @@ TEST_F(SemanticSegmentationForwardTest, ForwardAfterUnloadThrows) { // generateFromString tests // ============================================================================ TEST(SemanticSegmentationGenerateTests, InvalidImagePathThrows) { - BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, - kImageNetMean, kImageNetStd, kDeeplabV3Labels, - nullptr); - EXPECT_THROW( - (void)model.generateFromString("nonexistent_image.jpg", {}, true), - RnExecutorchError); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, kImageNetMean, kImageNetStd, + kDeeplabV3Labels, nullptr); + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg", {}, true), + RnExecutorchError); } TEST(SemanticSegmentationGenerateTests, EmptyImagePathThrows) { - BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, - kImageNetMean, kImageNetStd, kDeeplabV3Labels, - nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, kImageNetMean, kImageNetStd, + kDeeplabV3Labels, nullptr); EXPECT_THROW((void)model.generateFromString("", {}, true), RnExecutorchError); } TEST(SemanticSegmentationGenerateTests, MalformedURIThrows) { - BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, - kImageNetMean, kImageNetStd, kDeeplabV3Labels, - nullptr); - EXPECT_THROW( - (void)model.generateFromString("not_a_valid_uri://bad", {}, true), - RnExecutorchError); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, kImageNetMean, kImageNetStd, + kDeeplabV3Labels, nullptr); + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad", {}, true), + RnExecutorchError); } TEST(SemanticSegmentationGenerateTests, ValidImageNoFilterReturnsResult) { - BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, - kImageNetMean, kImageNetStd, kDeeplabV3Labels, - nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, kImageNetMean, kImageNetStd, + kDeeplabV3Labels, nullptr); auto result = model.generateFromString(kValidTestImagePath, {}, true); EXPECT_NE(result.argmax, nullptr); EXPECT_NE(result.classBuffers, nullptr); } TEST(SemanticSegmentationGenerateTests, ValidImageReturnsAllClasses) { - BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, - kImageNetMean, kImageNetStd, kDeeplabV3Labels, - nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, kImageNetMean, kImageNetStd, + kDeeplabV3Labels, nullptr); auto result = model.generateFromString(kValidTestImagePath, {}, true); ASSERT_NE(result.classBuffers, nullptr); EXPECT_EQ(result.classBuffers->size(), 21u); } TEST(SemanticSegmentationGenerateTests, ClassFilterLimitsClassBuffers) { - BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, - kImageNetMean, kImageNetStd, kDeeplabV3Labels, - nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, kImageNetMean, kImageNetStd, + kDeeplabV3Labels, nullptr); std::set> filter = {"PERSON", "CAT"}; auto result = model.generateFromString(kValidTestImagePath, filter, true); ASSERT_NE(result.classBuffers, nullptr); @@ -180,9 +163,8 @@ TEST(SemanticSegmentationGenerateTests, ClassFilterLimitsClassBuffers) { } TEST(SemanticSegmentationGenerateTests, ResizeFalseReturnsResult) { - BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, - kImageNetMean, kImageNetStd, kDeeplabV3Labels, - nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, kImageNetMean, kImageNetStd, + kDeeplabV3Labels, nullptr); auto result = model.generateFromString(kValidTestImagePath, {}, false); EXPECT_NE(result.argmax, nullptr); } @@ -191,9 +173,8 @@ TEST(SemanticSegmentationGenerateTests, ResizeFalseReturnsResult) { // generateFromPixels tests // ============================================================================ TEST(SemanticSegmentationPixelTests, ValidPixelsNoFilterReturnsResult) { - BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, - kImageNetMean, kImageNetStd, kDeeplabV3Labels, - nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, kImageNetMean, kImageNetStd, + kDeeplabV3Labels, nullptr); std::vector buf; auto view = makeRgbView(buf, 64, 64); auto result = model.generateFromPixels(view, {}, true); @@ -202,9 +183,8 @@ TEST(SemanticSegmentationPixelTests, ValidPixelsNoFilterReturnsResult) { } TEST(SemanticSegmentationPixelTests, ValidPixelsReturnsAllClasses) { - BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, - kImageNetMean, kImageNetStd, kDeeplabV3Labels, - nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, kImageNetMean, kImageNetStd, + kDeeplabV3Labels, nullptr); std::vector buf; auto view = makeRgbView(buf, 64, 64); auto result = model.generateFromPixels(view, {}, true); @@ -213,9 +193,8 @@ TEST(SemanticSegmentationPixelTests, ValidPixelsReturnsAllClasses) { } TEST(SemanticSegmentationPixelTests, ClassFilterLimitsClassBuffers) { - BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, - kImageNetMean, kImageNetStd, kDeeplabV3Labels, - nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, kImageNetMean, kImageNetStd, + kDeeplabV3Labels, nullptr); std::vector buf; auto view = makeRgbView(buf, 64, 64); std::set> filter = {"PERSON"}; @@ -230,9 +209,8 @@ TEST(SemanticSegmentationPixelTests, ClassFilterLimitsClassBuffers) { // Inherited BaseModel tests // ============================================================================ TEST(SemanticSegmentationInheritedTests, GetInputShapeWorks) { - BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, - kImageNetMean, kImageNetStd, kDeeplabV3Labels, - nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, kImageNetMean, kImageNetStd, + kDeeplabV3Labels, nullptr); auto shape = model.getInputShape("forward", 0); EXPECT_EQ(shape.size(), 4); EXPECT_EQ(shape[0], 1); // Batch size @@ -240,32 +218,28 @@ TEST(SemanticSegmentationInheritedTests, GetInputShapeWorks) { } TEST(SemanticSegmentationInheritedTests, GetAllInputShapesWorks) { - BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, - kImageNetMean, kImageNetStd, kDeeplabV3Labels, - nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, kImageNetMean, kImageNetStd, + kDeeplabV3Labels, nullptr); auto shapes = model.getAllInputShapes("forward"); EXPECT_FALSE(shapes.empty()); } TEST(SemanticSegmentationInheritedTests, GetMethodMetaWorks) { - BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, - kImageNetMean, kImageNetStd, kDeeplabV3Labels, - nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, kImageNetMean, kImageNetStd, + kDeeplabV3Labels, nullptr); auto result = model.getMethodMeta("forward"); EXPECT_TRUE(result.ok()); } TEST(SemanticSegmentationInheritedTests, GetMemoryLowerBoundReturnsPositive) { - BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, - kImageNetMean, kImageNetStd, kDeeplabV3Labels, - nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, kImageNetMean, kImageNetStd, + kDeeplabV3Labels, nullptr); EXPECT_GT(model.getMemoryLowerBound(), 0u); } TEST(SemanticSegmentationInheritedTests, InputShapeIsSquare) { - BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, - kImageNetMean, kImageNetStd, kDeeplabV3Labels, - nullptr); + BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath, kImageNetMean, kImageNetStd, + kDeeplabV3Labels, nullptr); auto shape = model.getInputShape("forward", 0); EXPECT_EQ(shape[2], shape[3]); // Height == Width for DeepLabV3 } diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/SpeechToTextTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/SpeechToTextTest.cpp index cc326d8bc0..b2348937cb 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/SpeechToTextTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/SpeechToTextTest.cpp @@ -46,15 +46,14 @@ template <> struct ModelTraits { } // namespace model_tests using SpeechToTextTypes = ::testing::Types; -INSTANTIATE_TYPED_TEST_SUITE_P(SpeechToText, CommonModelTest, - SpeechToTextTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(SpeechToText, CommonModelTest, SpeechToTextTypes); // ============================================================================ // Model-specific tests // ============================================================================ TEST(S2TCtorTests, InvalidModelNameThrows) { - EXPECT_THROW(SpeechToText("invalid_model", kValidModelPath, - kValidTokenizerPath, /*vadSource=*/"", nullptr), + EXPECT_THROW(SpeechToText("invalid_model", kValidModelPath, kValidTokenizerPath, /*vadSource=*/"", + nullptr), RnExecutorchError); } @@ -111,29 +110,28 @@ TEST(S2TTranscribeTests, InvalidLanguageThrows) { /*vadSource=*/"", nullptr); auto audio = loadAudioFromFile("test_audio_float.raw"); ASSERT_FALSE(audio.empty()); - EXPECT_THROW((void)model.transcribe(audio, "invalid_language_code", false), - RnExecutorchError); + EXPECT_THROW((void)model.transcribe(audio, "invalid_language_code", false), RnExecutorchError); } // ============================================================================ // VAD integration tests (vadSource provided => internal VAD module loaded) // ============================================================================ TEST(S2TVadCtorTests, ValidVadSourceConstructs) { - EXPECT_NO_THROW(SpeechToText("whisper", kValidModelPath, kValidTokenizerPath, - kValidVadModelPath, createMockCallInvoker())); + EXPECT_NO_THROW(SpeechToText("whisper", kValidModelPath, kValidTokenizerPath, kValidVadModelPath, + createMockCallInvoker())); } TEST(S2TVadCtorTests, InvalidVadSourceThrows) { - EXPECT_THROW(SpeechToText("whisper", kValidModelPath, kValidTokenizerPath, - "nonexistent_vad.pte", createMockCallInvoker()), + EXPECT_THROW(SpeechToText("whisper", kValidModelPath, kValidTokenizerPath, "nonexistent_vad.pte", + createMockCallInvoker()), RnExecutorchError); } TEST(S2TVadTranscribeTests, TranscribeStillWorksWithVadLoaded) { // The vadSource only affects streaming. The one-shot transcribe() path // must remain unchanged when a VAD is attached. - SpeechToText model("whisper", kValidModelPath, kValidTokenizerPath, - kValidVadModelPath, createMockCallInvoker()); + SpeechToText model("whisper", kValidModelPath, kValidTokenizerPath, kValidVadModelPath, + createMockCallInvoker()); auto silence = generateSilence(16000 * 5); auto result = model.transcribe(silence, "en", false); EXPECT_TRUE(result.text.empty()); @@ -154,8 +152,8 @@ TEST(S2TVadStreamTests, StreamWithVadOnSilenceCompletesCleanly) { // Drives the OnlineASR::process VAD branch with audio that contains no // speech segments. Exercises the "speechSegments.empty()" cleanup path // added by the PR. - SpeechToText model("whisper", kValidModelPath, kValidTokenizerPath, - kValidVadModelPath, createMockCallInvoker()); + SpeechToText model("whisper", kValidModelPath, kValidTokenizerPath, kValidVadModelPath, + createMockCallInvoker()); std::thread streamer([&model] { model.stream(std::shared_ptr(), /*language=*/"en", /*verbose=*/false, /*timeout=*/100, /*useVAD=*/true, diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp index a4511cad11..27190f61f8 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp @@ -11,16 +11,12 @@ using namespace rnexecutorch; using namespace rnexecutorch::models::style_transfer; using namespace model_tests; -constexpr auto kValidStyleTransferModelPath = - "style_transfer_candy_xnnpack_fp32.pte"; -constexpr auto kValidTestImagePath = - "file:///data/local/tmp/rnexecutorch_tests/test_image.jpg"; +constexpr auto kValidStyleTransferModelPath = "style_transfer_candy_xnnpack_fp32.pte"; +constexpr auto kValidTestImagePath = "file:///data/local/tmp/rnexecutorch_tests/test_image.jpg"; -static JSTensorViewIn makeRgbView(std::vector &buf, int32_t h, - int32_t w) { +static JSTensorViewIn makeRgbView(std::vector &buf, int32_t h, int32_t w) { buf.assign(static_cast(h * w * 3), 128); - return JSTensorViewIn{ - buf.data(), {h, w, 3}, executorch::aten::ScalarType::Byte}; + return JSTensorViewIn{buf.data(), {h, w, 3}, executorch::aten::ScalarType::Byte}; } // ============================================================================ @@ -30,13 +26,9 @@ namespace model_tests { template <> struct ModelTraits { using ModelType = StyleTransfer; - static ModelType createValid() { - return ModelType(kValidStyleTransferModelPath, nullptr); - } + static ModelType createValid() { return ModelType(kValidStyleTransferModelPath, nullptr); } - static ModelType createInvalid() { - return ModelType("nonexistent.pte", nullptr); - } + static ModelType createInvalid() { return ModelType("nonexistent.pte", nullptr); } static void callGenerate(ModelType &model) { (void)model.generateFromString(kValidTestImagePath, false); @@ -45,18 +37,15 @@ template <> struct ModelTraits { } // namespace model_tests using StyleTransferTypes = ::testing::Types; -INSTANTIATE_TYPED_TEST_SUITE_P(StyleTransfer, CommonModelTest, - StyleTransferTypes); -INSTANTIATE_TYPED_TEST_SUITE_P(StyleTransfer, VisionModelTest, - StyleTransferTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(StyleTransfer, CommonModelTest, StyleTransferTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(StyleTransfer, VisionModelTest, StyleTransferTypes); // ============================================================================ // generateFromString tests // ============================================================================ TEST(StyleTransferGenerateTests, InvalidImagePathThrows) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); - EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg", false), - RnExecutorchError); + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg", false), RnExecutorchError); } TEST(StyleTransferGenerateTests, EmptyImagePathThrows) { @@ -66,8 +55,7 @@ TEST(StyleTransferGenerateTests, EmptyImagePathThrows) { TEST(StyleTransferGenerateTests, MalformedURIThrows) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); - EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad", false), - RnExecutorchError); + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad", false), RnExecutorchError); } TEST(StyleTransferGenerateTests, ValidImageReturnsFilePath) { @@ -142,8 +130,7 @@ TEST(StyleTransferPixelTests, ValidPixelsSaveToFileFalseHasPositiveDimensions) { EXPECT_GT(pr.height, 0); } -TEST(StyleTransferPixelTests, - ValidPixelsSaveToFileTrueReturnsFileSchemeString) { +TEST(StyleTransferPixelTests, ValidPixelsSaveToFileTrueReturnsFileSchemeString) { StyleTransfer model(kValidStyleTransferModelPath, nullptr); std::vector buf; auto view = makeRgbView(buf, 64, 64); diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextEmbeddingsTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextEmbeddingsTest.cpp index ff1abd4c30..0c2b774575 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextEmbeddingsTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextEmbeddingsTest.cpp @@ -21,61 +21,52 @@ template <> struct ModelTraits { using ModelType = TextEmbeddings; static ModelType createValid() { - return ModelType(kValidTextEmbeddingsModelPath, - kValidTextEmbeddingsTokenizerPath, nullptr); + return ModelType(kValidTextEmbeddingsModelPath, kValidTextEmbeddingsTokenizerPath, nullptr); } static ModelType createInvalid() { - return ModelType("nonexistent.pte", kValidTextEmbeddingsTokenizerPath, - nullptr); + return ModelType("nonexistent.pte", kValidTextEmbeddingsTokenizerPath, nullptr); } - static void callGenerate(ModelType &model) { - (void)model.generate("Hello, world!"); - } + static void callGenerate(ModelType &model) { (void)model.generate("Hello, world!"); } }; } // namespace model_tests using TextEmbeddingsTypes = ::testing::Types; -INSTANTIATE_TYPED_TEST_SUITE_P(TextEmbeddings, CommonModelTest, - TextEmbeddingsTypes); +INSTANTIATE_TYPED_TEST_SUITE_P(TextEmbeddings, CommonModelTest, TextEmbeddingsTypes); // ============================================================================ // Model-specific tests // ============================================================================ TEST(TextEmbeddingsCtorTests, InvalidTokenizerPathThrows) { - EXPECT_THROW(TextEmbeddings(kValidTextEmbeddingsModelPath, - "this_tokenizer_does_not_exist.json", nullptr), - std::exception); + EXPECT_THROW( + TextEmbeddings(kValidTextEmbeddingsModelPath, "this_tokenizer_does_not_exist.json", nullptr), + std::exception); } TEST(TextEmbeddingsGenerateTests, EmptyStringReturnsResults) { - TextEmbeddings model(kValidTextEmbeddingsModelPath, - kValidTextEmbeddingsTokenizerPath, nullptr); + TextEmbeddings model(kValidTextEmbeddingsModelPath, kValidTextEmbeddingsTokenizerPath, nullptr); auto result = model.generate(""); EXPECT_NE(result, nullptr); EXPECT_GT(result->size(), 0u); } TEST(TextEmbeddingsGenerateTests, ValidTextReturnsResults) { - TextEmbeddings model(kValidTextEmbeddingsModelPath, - kValidTextEmbeddingsTokenizerPath, nullptr); + TextEmbeddings model(kValidTextEmbeddingsModelPath, kValidTextEmbeddingsTokenizerPath, nullptr); auto result = model.generate("Hello, world!"); EXPECT_NE(result, nullptr); EXPECT_GT(result->size(), 0u); } TEST(TextEmbeddingsGenerateTests, ResultsHaveCorrectSize) { - TextEmbeddings model(kValidTextEmbeddingsModelPath, - kValidTextEmbeddingsTokenizerPath, nullptr); + TextEmbeddings model(kValidTextEmbeddingsModelPath, kValidTextEmbeddingsTokenizerPath, nullptr); auto result = model.generate("This is a test sentence."); size_t numFloats = result->size() / sizeof(float); EXPECT_EQ(numFloats, kMiniLmEmbeddingDimensions); } TEST(TextEmbeddingsGenerateTests, ResultsAreNormalized) { - TextEmbeddings model(kValidTextEmbeddingsModelPath, - kValidTextEmbeddingsTokenizerPath, nullptr); + TextEmbeddings model(kValidTextEmbeddingsModelPath, kValidTextEmbeddingsTokenizerPath, nullptr); auto result = model.generate("The quick brown fox jumps over the lazy dog."); const float *data = reinterpret_cast(result->data()); @@ -90,8 +81,7 @@ TEST(TextEmbeddingsGenerateTests, ResultsAreNormalized) { } TEST(TextEmbeddingsGenerateTests, ResultsContainValidValues) { - TextEmbeddings model(kValidTextEmbeddingsModelPath, - kValidTextEmbeddingsTokenizerPath, nullptr); + TextEmbeddings model(kValidTextEmbeddingsModelPath, kValidTextEmbeddingsTokenizerPath, nullptr); auto result = model.generate("Testing valid values."); const float *data = reinterpret_cast(result->data()); @@ -104,8 +94,7 @@ TEST(TextEmbeddingsGenerateTests, ResultsContainValidValues) { } TEST(TextEmbeddingsGenerateTests, DifferentTextProducesDifferentEmbeddings) { - TextEmbeddings model(kValidTextEmbeddingsModelPath, - kValidTextEmbeddingsTokenizerPath, nullptr); + TextEmbeddings model(kValidTextEmbeddingsModelPath, kValidTextEmbeddingsTokenizerPath, nullptr); auto result1 = model.generate("Hello, world!"); auto result2 = model.generate("Goodbye, moon!"); @@ -125,8 +114,7 @@ TEST(TextEmbeddingsGenerateTests, DifferentTextProducesDifferentEmbeddings) { } TEST(TextEmbeddingsGenerateTests, SimilarTextProducesSimilarEmbeddings) { - TextEmbeddings model(kValidTextEmbeddingsModelPath, - kValidTextEmbeddingsTokenizerPath, nullptr); + TextEmbeddings model(kValidTextEmbeddingsModelPath, kValidTextEmbeddingsTokenizerPath, nullptr); auto result1 = model.generate("I love programming"); auto result2 = model.generate("I enjoy coding"); @@ -143,22 +131,19 @@ TEST(TextEmbeddingsGenerateTests, SimilarTextProducesSimilarEmbeddings) { } TEST(TextEmbeddingsInheritedTests, GetInputShapeWorks) { - TextEmbeddings model(kValidTextEmbeddingsModelPath, - kValidTextEmbeddingsTokenizerPath, nullptr); + TextEmbeddings model(kValidTextEmbeddingsModelPath, kValidTextEmbeddingsTokenizerPath, nullptr); auto shape = model.getInputShape("forward", 0); EXPECT_GE(shape.size(), 2u); } TEST(TextEmbeddingsInheritedTests, GetAllInputShapesWorks) { - TextEmbeddings model(kValidTextEmbeddingsModelPath, - kValidTextEmbeddingsTokenizerPath, nullptr); + TextEmbeddings model(kValidTextEmbeddingsModelPath, kValidTextEmbeddingsTokenizerPath, nullptr); auto shapes = model.getAllInputShapes("forward"); EXPECT_FALSE(shapes.empty()); } TEST(TextEmbeddingsInheritedTests, GetMethodMetaWorks) { - TextEmbeddings model(kValidTextEmbeddingsModelPath, - kValidTextEmbeddingsTokenizerPath, nullptr); + TextEmbeddings model(kValidTextEmbeddingsModelPath, kValidTextEmbeddingsTokenizerPath, nullptr); auto result = model.getMethodMeta("forward"); EXPECT_TRUE(result.ok()); } diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextToImageTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextToImageTest.cpp index 7ecae1f392..5741912376 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextToImageTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextToImageTest.cpp @@ -32,22 +32,18 @@ template <> struct ModelTraits { using ModelType = TextToImage; static ModelType createValid() { - return ModelType(kValidTokenizerPath, kValidEncoderPath, kValidUnetPath, - kValidDecoderPath, kSchedulerBetaStart, kSchedulerBetaEnd, - kSchedulerNumTrainTimesteps, kSchedulerStepsOffset, - rnexecutorch::createMockCallInvoker()); + return ModelType(kValidTokenizerPath, kValidEncoderPath, kValidUnetPath, kValidDecoderPath, + kSchedulerBetaStart, kSchedulerBetaEnd, kSchedulerNumTrainTimesteps, + kSchedulerStepsOffset, rnexecutorch::createMockCallInvoker()); } static ModelType createInvalid() { - return ModelType("nonexistent.json", kValidEncoderPath, kValidUnetPath, - kValidDecoderPath, kSchedulerBetaStart, kSchedulerBetaEnd, - kSchedulerNumTrainTimesteps, kSchedulerStepsOffset, - rnexecutorch::createMockCallInvoker()); + return ModelType("nonexistent.json", kValidEncoderPath, kValidUnetPath, kValidDecoderPath, + kSchedulerBetaStart, kSchedulerBetaEnd, kSchedulerNumTrainTimesteps, + kSchedulerStepsOffset, rnexecutorch::createMockCallInvoker()); } - static void callGenerate(ModelType &model) { - (void)model.generate("a cat", 128, 1, 42, nullptr); - } + static void callGenerate(ModelType &model) { (void)model.generate("a cat", 128, 1, 42, nullptr); } }; } // namespace model_tests @@ -60,67 +56,57 @@ INSTANTIATE_TYPED_TEST_SUITE_P(TextToImage, CommonModelTest, TextToImageTypes); // Model-specific tests // ============================================================================ TEST(TextToImageCtorTests, InvalidEncoderPathThrows) { - EXPECT_THROW(TextToImage(kValidTokenizerPath, "nonexistent.pte", - kValidUnetPath, kValidDecoderPath, - kSchedulerBetaStart, kSchedulerBetaEnd, + EXPECT_THROW(TextToImage(kValidTokenizerPath, "nonexistent.pte", kValidUnetPath, + kValidDecoderPath, kSchedulerBetaStart, kSchedulerBetaEnd, kSchedulerNumTrainTimesteps, kSchedulerStepsOffset, createMockCallInvoker()), RnExecutorchError); } TEST(TextToImageCtorTests, InvalidUnetPathThrows) { - EXPECT_THROW(TextToImage(kValidTokenizerPath, kValidEncoderPath, - "nonexistent.pte", kValidDecoderPath, - kSchedulerBetaStart, kSchedulerBetaEnd, + EXPECT_THROW(TextToImage(kValidTokenizerPath, kValidEncoderPath, "nonexistent.pte", + kValidDecoderPath, kSchedulerBetaStart, kSchedulerBetaEnd, kSchedulerNumTrainTimesteps, kSchedulerStepsOffset, createMockCallInvoker()), RnExecutorchError); } TEST(TextToImageCtorTests, InvalidDecoderPathThrows) { - EXPECT_THROW(TextToImage(kValidTokenizerPath, kValidEncoderPath, - kValidUnetPath, "nonexistent.pte", - kSchedulerBetaStart, kSchedulerBetaEnd, + EXPECT_THROW(TextToImage(kValidTokenizerPath, kValidEncoderPath, kValidUnetPath, + "nonexistent.pte", kSchedulerBetaStart, kSchedulerBetaEnd, kSchedulerNumTrainTimesteps, kSchedulerStepsOffset, createMockCallInvoker()), RnExecutorchError); } TEST(TextToImageGenerateTests, InvalidImageSizeThrows) { - TextToImage model(kValidTokenizerPath, kValidEncoderPath, kValidUnetPath, - kValidDecoderPath, kSchedulerBetaStart, kSchedulerBetaEnd, - kSchedulerNumTrainTimesteps, kSchedulerStepsOffset, - createMockCallInvoker()); - EXPECT_THROW((void)model.generate("a cat", 100, 1, 42, nullptr), - RnExecutorchError); + TextToImage model(kValidTokenizerPath, kValidEncoderPath, kValidUnetPath, kValidDecoderPath, + kSchedulerBetaStart, kSchedulerBetaEnd, kSchedulerNumTrainTimesteps, + kSchedulerStepsOffset, createMockCallInvoker()); + EXPECT_THROW((void)model.generate("a cat", 100, 1, 42, nullptr), RnExecutorchError); } TEST(TextToImageGenerateTests, EmptyPromptThrows) { - TextToImage model(kValidTokenizerPath, kValidEncoderPath, kValidUnetPath, - kValidDecoderPath, kSchedulerBetaStart, kSchedulerBetaEnd, - kSchedulerNumTrainTimesteps, kSchedulerStepsOffset, - createMockCallInvoker()); - EXPECT_THROW((void)model.generate("", 128, 1, 42, nullptr), - RnExecutorchError); + TextToImage model(kValidTokenizerPath, kValidEncoderPath, kValidUnetPath, kValidDecoderPath, + kSchedulerBetaStart, kSchedulerBetaEnd, kSchedulerNumTrainTimesteps, + kSchedulerStepsOffset, createMockCallInvoker()); + EXPECT_THROW((void)model.generate("", 128, 1, 42, nullptr), RnExecutorchError); } TEST(TextToImageGenerateTests, ZeroStepsThrows) { - TextToImage model(kValidTokenizerPath, kValidEncoderPath, kValidUnetPath, - kValidDecoderPath, kSchedulerBetaStart, kSchedulerBetaEnd, - kSchedulerNumTrainTimesteps, kSchedulerStepsOffset, - createMockCallInvoker()); - EXPECT_THROW((void)model.generate("a cat", 128, 0, 42, nullptr), - RnExecutorchError); + TextToImage model(kValidTokenizerPath, kValidEncoderPath, kValidUnetPath, kValidDecoderPath, + kSchedulerBetaStart, kSchedulerBetaEnd, kSchedulerNumTrainTimesteps, + kSchedulerStepsOffset, createMockCallInvoker()); + EXPECT_THROW((void)model.generate("a cat", 128, 0, 42, nullptr), RnExecutorchError); } TEST(TextToImageGenerateTests, GenerateReturnsFileUri) { // TODO: Investigate source of the issue GTEST_SKIP() << "Skipping TextToImage generation test in emulator " "environment due to UNet forward call throwing error no. 1"; - TextToImage model(kValidTokenizerPath, kValidEncoderPath, kValidUnetPath, - kValidDecoderPath, kSchedulerBetaStart, kSchedulerBetaEnd, - kSchedulerNumTrainTimesteps, kSchedulerStepsOffset, - createMockCallInvoker()); + TextToImage model(kValidTokenizerPath, kValidEncoderPath, kValidUnetPath, kValidDecoderPath, + kSchedulerBetaStart, kSchedulerBetaEnd, kSchedulerNumTrainTimesteps, + kSchedulerStepsOffset, createMockCallInvoker()); auto result = model.generate("a cat", 128, 1, 42, nullptr); EXPECT_FALSE(result.empty()); EXPECT_TRUE(result.starts_with("file://")); @@ -130,10 +116,9 @@ TEST(TextToImageGenerateTests, SameSeedProducesSameResult) { // TODO: Investigate source of the issue GTEST_SKIP() << "Skipping TextToImage generation test in emulator " "environment due to UNet forward call throwing error no. 1"; - TextToImage model(kValidTokenizerPath, kValidEncoderPath, kValidUnetPath, - kValidDecoderPath, kSchedulerBetaStart, kSchedulerBetaEnd, - kSchedulerNumTrainTimesteps, kSchedulerStepsOffset, - createMockCallInvoker()); + TextToImage model(kValidTokenizerPath, kValidEncoderPath, kValidUnetPath, kValidDecoderPath, + kSchedulerBetaStart, kSchedulerBetaEnd, kSchedulerNumTrainTimesteps, + kSchedulerStepsOffset, createMockCallInvoker()); auto path1 = model.generate("a cat", 128, 1, 42, nullptr); auto path2 = model.generate("a cat", 128, 1, 42, nullptr); ASSERT_FALSE(path1.empty()); diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextToSpeechTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextToSpeechTest.cpp index bb1a201ebc..978c6b91b9 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextToSpeechTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextToSpeechTest.cpp @@ -28,12 +28,11 @@ bool isAudioValid(const std::vector &audio) { } // Check for non-silence (amplitude greater than an arbitrary small noise // threshold) - return std::ranges::any_of( - audio, [](float sample) { return std::abs(sample) > 1e-4f; }); + return std::ranges::any_of(audio, [](float sample) { return std::abs(sample) > 1e-4f; }); } -bool isAudioSimilar(const std::vector &audio1, - const std::vector &audio2, float tolerance = 0.1f) { +bool isAudioSimilar(const std::vector &audio1, const std::vector &audio2, + float tolerance = 0.1f) { if (audio1.empty() || audio2.empty()) { return false; } @@ -42,10 +41,8 @@ bool isAudioSimilar(const std::vector &audio1, size_t steps = std::max(audio1.size(), audio2.size()); for (size_t i = 0; i < steps; ++i) { - size_t idx1 = - static_cast((static_cast(i) / steps) * audio1.size()); - size_t idx2 = - static_cast((static_cast(i) / steps) * audio2.size()); + size_t idx1 = static_cast((static_cast(i) / steps) * audio1.size()); + size_t idx2 = static_cast((static_cast(i) / steps) * audio2.size()); float diff = audio1[idx1] - audio2[idx2]; sumSqDiff += diff * diff; @@ -53,8 +50,8 @@ bool isAudioSimilar(const std::vector &audio1, double rmse = std::sqrt(sumSqDiff / steps); if (rmse >= tolerance) { - std::cerr << "Audio structural RMSE difference: " << rmse - << " (tolerance: " << tolerance << ")" << std::endl; + std::cerr << "Audio structural RMSE difference: " << rmse << " (tolerance: " << tolerance << ")" + << std::endl; return false; } return true; @@ -64,9 +61,9 @@ class KokoroTest : public ::testing::Test { protected: void SetUp() override { try { - model_ = std::make_unique( - kValidLang, kValidTaggerPath, kValidLexiconPath, kValidPhonemizerPath, - kValidDurationPath, kValidSynthesizerPath, kValidVoicePath, nullptr); + model_ = std::make_unique(kValidLang, kValidTaggerPath, kValidLexiconPath, + kValidPhonemizerPath, kValidDurationPath, + kValidSynthesizerPath, kValidVoicePath, nullptr); } catch (...) { model_ = nullptr; } @@ -77,9 +74,8 @@ class KokoroTest : public ::testing::Test { } // namespace TEST(TTSCtorTests, InvalidVoicePathThrows) { - EXPECT_THROW(Kokoro(kValidLang, kValidTaggerPath, kValidLexiconPath, - kValidPhonemizerPath, kValidDurationPath, - kValidSynthesizerPath, "nonexistent_voice.bin", nullptr), + EXPECT_THROW(Kokoro(kValidLang, kValidTaggerPath, kValidLexiconPath, kValidPhonemizerPath, + kValidDurationPath, kValidSynthesizerPath, "nonexistent_voice.bin", nullptr), RnExecutorchError); } @@ -106,8 +102,7 @@ TEST_F(KokoroTest, GenerateReturnsValidAudio) { auto result = model_->generate(U"Hello world! How are you doing?", 1.0f); auto reference = test_utils::loadAudioFromFile("test_speech.raw"); - ASSERT_FALSE(reference.empty()) - << "Reference audio 'test_speech.raw' not found."; + ASSERT_FALSE(reference.empty()) << "Reference audio 'test_speech.raw' not found."; // Compare against an audio waveform obtained from the original // Kokoro model (PyTorch) @@ -126,4 +121,4 @@ TEST_F(KokoroTest, GenerateSpeedAdjustsAudioLength) { EXPECT_TRUE(isAudioValid(resultFast)); // Fast speech should result in a noticeably shorter output waveform EXPECT_LT(resultFast.size(), resultNormal.size()); -} \ No newline at end of file +} diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/TokenizerModuleTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/TokenizerModuleTest.cpp index 393ce78d39..b41b6f42d0 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/TokenizerModuleTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/TokenizerModuleTest.cpp @@ -7,8 +7,7 @@ using namespace rnexecutorch; constexpr auto kValidTokenizerPath = "tokenizer.json"; TEST(TokenizerCtorTests, InvalidPathThrows) { - EXPECT_THROW(TokenizerModule("nonexistent_tokenizer.json", nullptr), - RnExecutorchError); + EXPECT_THROW(TokenizerModule("nonexistent_tokenizer.json", nullptr), RnExecutorchError); } TEST(TokenizerCtorTests, ValidPathDoesntThrow) { diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/VerticalOCRTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/VerticalOCRTest.cpp index c92abc0f15..caa788e601 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/VerticalOCRTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/VerticalOCRTest.cpp @@ -20,10 +20,9 @@ constexpr auto kValidVerticalTestImagePath = "file:///data/local/tmp/rnexecutorch_tests/we_are_software_mansion.jpg"; // English alphabet symbols (must match alphabets.english from symbols.ts) -const std::string ENGLISH_SYMBOLS = - "0123456789!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~ " - "\xE2\x82\xAC" // Euro sign (€) - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; +const std::string ENGLISH_SYMBOLS = "0123456789!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~ " + "\xE2\x82\xAC" // Euro sign (€) + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; // ============================================================================ // Common tests via typed test suite @@ -33,14 +32,12 @@ template <> struct ModelTraits { using ModelType = VerticalOCR; static ModelType createValid() { - return ModelType(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, - ENGLISH_SYMBOLS, false, - rnexecutorch::createMockCallInvoker()); + return ModelType(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, + false, rnexecutorch::createMockCallInvoker()); } static ModelType createInvalid() { - return ModelType("nonexistent.pte", kValidVerticalRecognizerPath, - ENGLISH_SYMBOLS, false, + return ModelType("nonexistent.pte", kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, false, rnexecutorch::createMockCallInvoker()); } @@ -59,61 +56,56 @@ INSTANTIATE_TYPED_TEST_SUITE_P(VerticalOCR, CommonModelTest, VerticalOCRTypes); // Constructor tests TEST(VerticalOCRCtorTests, InvalidRecognizerPathThrows) { - EXPECT_THROW(VerticalOCR(kValidVerticalDetectorPath, "nonexistent.pte", - ENGLISH_SYMBOLS, false, createMockCallInvoker()), + EXPECT_THROW(VerticalOCR(kValidVerticalDetectorPath, "nonexistent.pte", ENGLISH_SYMBOLS, false, + createMockCallInvoker()), RnExecutorchError); } TEST(VerticalOCRCtorTests, EmptySymbolsThrows) { - EXPECT_THROW(VerticalOCR(kValidVerticalDetectorPath, - kValidVerticalRecognizerPath, "", false, + EXPECT_THROW(VerticalOCR(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, "", false, createMockCallInvoker()), RnExecutorchError); } TEST(VerticalOCRCtorTests, IndependentCharsTrueDoesntThrow) { - EXPECT_NO_THROW(VerticalOCR(kValidVerticalDetectorPath, - kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, - true, createMockCallInvoker())); + EXPECT_NO_THROW(VerticalOCR(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, + ENGLISH_SYMBOLS, true, createMockCallInvoker())); } TEST(VerticalOCRCtorTests, IndependentCharsFalseDoesntThrow) { - EXPECT_NO_THROW(VerticalOCR(kValidVerticalDetectorPath, - kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, - false, createMockCallInvoker())); + EXPECT_NO_THROW(VerticalOCR(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, + ENGLISH_SYMBOLS, false, createMockCallInvoker())); } // Generate tests - Independent Characters strategy TEST(VerticalOCRGenerateTests, IndependentCharsInvalidImageThrows) { - VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, - ENGLISH_SYMBOLS, true, createMockCallInvoker()); - EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"), - RnExecutorchError); + VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, true, + createMockCallInvoker()); + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"), RnExecutorchError); } TEST(VerticalOCRGenerateTests, IndependentCharsEmptyImagePathThrows) { - VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, - ENGLISH_SYMBOLS, true, createMockCallInvoker()); + VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, true, + createMockCallInvoker()); EXPECT_THROW((void)model.generateFromString(""), RnExecutorchError); } TEST(VerticalOCRGenerateTests, IndependentCharsMalformedURIThrows) { - VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, - ENGLISH_SYMBOLS, true, createMockCallInvoker()); - EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"), - RnExecutorchError); + VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, true, + createMockCallInvoker()); + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"), RnExecutorchError); } TEST(VerticalOCRGenerateTests, IndependentCharsValidImageReturnsResults) { - VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, - ENGLISH_SYMBOLS, true, createMockCallInvoker()); + VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, true, + createMockCallInvoker()); auto results = model.generateFromString(kValidVerticalTestImagePath); EXPECT_GE(results.size(), 0u); } TEST(VerticalOCRGenerateTests, IndependentCharsDetectionsHaveValidBBoxes) { - VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, - ENGLISH_SYMBOLS, true, createMockCallInvoker()); + VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, true, + createMockCallInvoker()); auto results = model.generateFromString(kValidVerticalTestImagePath); for (const auto &detection : results) { @@ -125,8 +117,8 @@ TEST(VerticalOCRGenerateTests, IndependentCharsDetectionsHaveValidBBoxes) { } TEST(VerticalOCRGenerateTests, IndependentCharsDetectionsHaveValidScores) { - VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, - ENGLISH_SYMBOLS, true, createMockCallInvoker()); + VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, true, + createMockCallInvoker()); auto results = model.generateFromString(kValidVerticalTestImagePath); for (const auto &detection : results) { @@ -136,8 +128,8 @@ TEST(VerticalOCRGenerateTests, IndependentCharsDetectionsHaveValidScores) { } TEST(VerticalOCRGenerateTests, IndependentCharsDetectionsHaveNonEmptyText) { - VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, - ENGLISH_SYMBOLS, true, createMockCallInvoker()); + VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, true, + createMockCallInvoker()); auto results = model.generateFromString(kValidVerticalTestImagePath); for (const auto &detection : results) { @@ -147,35 +139,33 @@ TEST(VerticalOCRGenerateTests, IndependentCharsDetectionsHaveNonEmptyText) { // Generate tests - Joint Characters strategy TEST(VerticalOCRGenerateTests, JointCharsInvalidImageThrows) { - VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, - ENGLISH_SYMBOLS, false, createMockCallInvoker()); - EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"), - RnExecutorchError); + VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, + false, createMockCallInvoker()); + EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"), RnExecutorchError); } TEST(VerticalOCRGenerateTests, JointCharsEmptyImagePathThrows) { - VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, - ENGLISH_SYMBOLS, false, createMockCallInvoker()); + VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, + false, createMockCallInvoker()); EXPECT_THROW((void)model.generateFromString(""), RnExecutorchError); } TEST(VerticalOCRGenerateTests, JointCharsMalformedURIThrows) { - VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, - ENGLISH_SYMBOLS, false, createMockCallInvoker()); - EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"), - RnExecutorchError); + VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, + false, createMockCallInvoker()); + EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"), RnExecutorchError); } TEST(VerticalOCRGenerateTests, JointCharsValidImageReturnsResults) { - VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, - ENGLISH_SYMBOLS, false, createMockCallInvoker()); + VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, + false, createMockCallInvoker()); auto results = model.generateFromString(kValidVerticalTestImagePath); EXPECT_GE(results.size(), 0u); } TEST(VerticalOCRGenerateTests, JointCharsDetectionsHaveValidBBoxes) { - VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, - ENGLISH_SYMBOLS, false, createMockCallInvoker()); + VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, + false, createMockCallInvoker()); auto results = model.generateFromString(kValidVerticalTestImagePath); for (const auto &detection : results) { @@ -187,8 +177,8 @@ TEST(VerticalOCRGenerateTests, JointCharsDetectionsHaveValidBBoxes) { } TEST(VerticalOCRGenerateTests, JointCharsDetectionsHaveValidScores) { - VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, - ENGLISH_SYMBOLS, false, createMockCallInvoker()); + VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, + false, createMockCallInvoker()); auto results = model.generateFromString(kValidVerticalTestImagePath); for (const auto &detection : results) { @@ -198,8 +188,8 @@ TEST(VerticalOCRGenerateTests, JointCharsDetectionsHaveValidScores) { } TEST(VerticalOCRGenerateTests, JointCharsDetectionsHaveNonEmptyText) { - VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, - ENGLISH_SYMBOLS, false, createMockCallInvoker()); + VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, + false, createMockCallInvoker()); auto results = model.generateFromString(kValidVerticalTestImagePath); for (const auto &detection : results) { @@ -209,31 +199,23 @@ TEST(VerticalOCRGenerateTests, JointCharsDetectionsHaveNonEmptyText) { // Strategy comparison tests TEST(VerticalOCRStrategyTests, BothStrategiesRunSuccessfully) { - VerticalOCR independentModel(kValidVerticalDetectorPath, - kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, - true, createMockCallInvoker()); - VerticalOCR jointModel(kValidVerticalDetectorPath, - kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, false, - createMockCallInvoker()); + VerticalOCR independentModel(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, + ENGLISH_SYMBOLS, true, createMockCallInvoker()); + VerticalOCR jointModel(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, + false, createMockCallInvoker()); - EXPECT_NO_THROW( - (void)independentModel.generateFromString(kValidVerticalTestImagePath)); - EXPECT_NO_THROW( - (void)jointModel.generateFromString(kValidVerticalTestImagePath)); + EXPECT_NO_THROW((void)independentModel.generateFromString(kValidVerticalTestImagePath)); + EXPECT_NO_THROW((void)jointModel.generateFromString(kValidVerticalTestImagePath)); } TEST(VerticalOCRStrategyTests, BothStrategiesReturnValidResults) { - VerticalOCR independentModel(kValidVerticalDetectorPath, - kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, - true, createMockCallInvoker()); - VerticalOCR jointModel(kValidVerticalDetectorPath, - kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, false, - createMockCallInvoker()); - - auto independentResults = - independentModel.generateFromString(kValidVerticalTestImagePath); - auto jointResults = - jointModel.generateFromString(kValidVerticalTestImagePath); + VerticalOCR independentModel(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, + ENGLISH_SYMBOLS, true, createMockCallInvoker()); + VerticalOCR jointModel(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, + false, createMockCallInvoker()); + + auto independentResults = independentModel.generateFromString(kValidVerticalTestImagePath); + auto jointResults = jointModel.generateFromString(kValidVerticalTestImagePath); // Both should return some results (or none if no text detected) EXPECT_GE(independentResults.size(), 0u); @@ -244,11 +226,10 @@ TEST(VerticalOCRStrategyTests, BothStrategiesReturnValidResults) { // generateFromPixels smoke test // ============================================================================ TEST(VerticalOCRPixelTests, ValidPixelsReturnsResults) { - VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, - ENGLISH_SYMBOLS, false, createMockCallInvoker()); + VerticalOCR model(kValidVerticalDetectorPath, kValidVerticalRecognizerPath, ENGLISH_SYMBOLS, + false, createMockCallInvoker()); std::vector buf(64 * 64 * 3, 128); - JSTensorViewIn view{ - buf.data(), {64, 64, 3}, executorch::aten::ScalarType::Byte}; + JSTensorViewIn view{buf.data(), {64, 64, 3}, executorch::aten::ScalarType::Byte}; auto results = model.generateFromPixels(view); EXPECT_GE(results.size(), 0u); } diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/VisionModelTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/VisionModelTest.cpp index 6736454d6f..8fe5c15314 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/VisionModelTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/VisionModelTest.cpp @@ -15,18 +15,13 @@ using executorch::aten::ScalarType; // ============================================================================ class TestableVisionModel : public VisionModel { public: - explicit TestableVisionModel(const std::string &path) - : VisionModel(path, nullptr) {} + explicit TestableVisionModel(const std::string &path) : VisionModel(path, nullptr) {} cv::Mat preprocessPublic(const cv::Mat &img) const { return preprocess(img); } - cv::Mat extractFromPixelsPublic(const JSTensorViewIn &v) const { - return extractFromPixels(v); - } + cv::Mat extractFromPixelsPublic(const JSTensorViewIn &v) const { return extractFromPixels(v); } - void setInputShape(std::vector shape) { - modelInputShape_ = std::move(shape); - } + void setInputShape(std::vector shape) { modelInputShape_ = std::move(shape); } }; // Reuse the style_transfer .pte as a vehicle — we never call forward(). @@ -37,9 +32,7 @@ constexpr auto kModelPath = "style_transfer_candy_xnnpack_fp32.pte"; // ============================================================================ class VisionModelPreprocessTest : public ::testing::Test { protected: - void SetUp() override { - model = std::make_unique(kModelPath); - } + void SetUp() override { model = std::make_unique(kModelPath); } std::unique_ptr model; }; @@ -87,9 +80,7 @@ TEST_F(VisionModelPreprocessTest, NonSquareTargetSize) { // ============================================================================ class VisionModelExtractFromPixelsTest : public ::testing::Test { protected: - void SetUp() override { - model = std::make_unique(kModelPath); - } + void SetUp() override { model = std::make_unique(kModelPath); } std::unique_ptr model; }; diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/VoiceActivityDetectionTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/VoiceActivityDetectionTest.cpp index e60b375f1d..f36b21745d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/VoiceActivityDetectionTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/VoiceActivityDetectionTest.cpp @@ -29,13 +29,9 @@ namespace model_tests { template <> struct ModelTraits { using ModelType = VoiceActivityDetection; - static ModelType createValid() { - return ModelType(kValidVadModelPath, nullptr); - } + static ModelType createValid() { return ModelType(kValidVadModelPath, nullptr); } - static ModelType createInvalid() { - return ModelType("nonexistent.pte", nullptr); - } + static ModelType createInvalid() { return ModelType("nonexistent.pte", nullptr); } static void callGenerate(ModelType &model) { auto audio = loadAudioFromFile("test_audio_float.raw"); diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/stubs/StubRunner.h b/packages/react-native-executorch/common/rnexecutorch/tests/integration/stubs/StubRunner.h index 023d6cf080..88b562ec04 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/stubs/StubRunner.h +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/stubs/StubRunner.h @@ -12,9 +12,9 @@ class StubRunner : public ::executorch::extension::llm::BaseLLMRunner { loaded_ = true; return ::executorch::runtime::Error::Ok; } - ::executorch::runtime::Error generate_internal( - const std::vector<::executorch::extension::llm::MultimodalInput> &, - std::function) override { + ::executorch::runtime::Error + generate_internal(const std::vector<::executorch::extension::llm::MultimodalInput> &, + std::function) override { return ::executorch::runtime::Error::Ok; } void stop_impl() override {} diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/stubs/jsi_stubs.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/stubs/jsi_stubs.cpp index 897a2778e8..ba3d2b55a4 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/stubs/jsi_stubs.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/stubs/jsi_stubs.cpp @@ -41,9 +41,7 @@ namespace rnexecutorch { // Stub for fetchUrlFunc - used by ImageProcessing for remote URLs // Tests only use local files, so this is never called using FetchUrlFunc_t = std::function(std::string)>; -FetchUrlFunc_t fetchUrlFunc = [](std::string) -> std::vector { - return {}; -}; +FetchUrlFunc_t fetchUrlFunc = [](std::string) -> std::vector { return {}; }; // Global mock call invoker for tests std::shared_ptr createMockCallInvoker() { diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/unit/FileUtilsTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/unit/FileUtilsTest.cpp index ed9d802361..c369a5224f 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/unit/FileUtilsTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/unit/FileUtilsTest.cpp @@ -26,7 +26,6 @@ TEST_F(FileIOTest, LoadBytesFromFileSuccessfully) { } TEST_F(FileIOTest, LoadBytesFromFileFailOnNonExistentFile) { - EXPECT_THROW( - { loadBytesFromFile("non_existent_file.txt"); }, RnExecutorchError); + EXPECT_THROW({ loadBytesFromFile("non_existent_file.txt"); }, RnExecutorchError); } } // namespace rnexecutorch::file_utils diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/unit/FrameProcessorTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/unit/FrameProcessorTest.cpp index cfea1eb2a4..bf3c6334b0 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/unit/FrameProcessorTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/unit/FrameProcessorTest.cpp @@ -9,8 +9,7 @@ using namespace rnexecutorch; using namespace rnexecutorch::utils; using executorch::aten::ScalarType; -static JSTensorViewIn makeValidView(std::vector &buf, int32_t h, - int32_t w) { +static JSTensorViewIn makeValidView(std::vector &buf, int32_t h, int32_t w) { buf.assign(static_cast(h * w * 3), 128); return JSTensorViewIn{buf.data(), {h, w, 3}, ScalarType::Byte}; } diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/unit/ImageProcessingTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/unit/ImageProcessingTest.cpp index d8a5a2a7fe..a6ff87f46f 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/unit/ImageProcessingTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/unit/ImageProcessingTest.cpp @@ -36,16 +36,14 @@ TEST(ReadImageTest, WorksWithRawBase64Content) { } TEST(ReadImageTest, FailsForInvalidBase64UriFormat) { - std::string invalidUri = - "data:image/jpeg;base64,extra,comma," + RAW_BASE64_JPEG; + std::string invalidUri = "data:image/jpeg;base64,extra,comma," + RAW_BASE64_JPEG; EXPECT_THROW({ readImage(invalidUri); }, RnExecutorchError); try { readImage(invalidUri); } catch (const RnExecutorchError &e) { - EXPECT_EQ(e.getNumericCode(), - static_cast(RnExecutorchErrorCode::FileReadFailed)); + EXPECT_EQ(e.getNumericCode(), static_cast(RnExecutorchErrorCode::FileReadFailed)); } } @@ -60,4 +58,4 @@ TEST(ReadImageTest, FailsForInvalidBase64Data) { EXPECT_STREQ(e.what(), "Read image error: invalid argument"); } } -} // namespace rnexecutorch::image_processing \ No newline at end of file +} // namespace rnexecutorch::image_processing diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/unit/LogTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/unit/LogTest.cpp index 1fb2f3fdf6..f8e8cee49f 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/unit/LogTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/unit/LogTest.cpp @@ -26,15 +26,13 @@ class TestValue : public ::testing::Test { TestValue() { oss << std::boolalpha; } template - void testValueViaComparison(const T &value, - const std::string &expectedOutput) { + void testValueViaComparison(const T &value, const std::string &expectedOutput) { printElement(oss, value); EXPECT_EQ(oss.str(), expectedOutput); clearOutputStream(oss); } - template - void testValueViaRegex(const T &value, const std::string &expectedPattern) { + template void testValueViaRegex(const T &value, const std::string &expectedPattern) { printElement(oss, value); const std::regex pattern(expectedPattern); EXPECT_TRUE(std::regex_search(oss.str(), pattern)) @@ -88,9 +86,7 @@ class Point final { int x, y; }; -TEST_F(DirectStreamableElementsPrintTest, HandlesIntegers) { - testValueViaComparison(123, "123"); -} +TEST_F(DirectStreamableElementsPrintTest, HandlesIntegers) { testValueViaComparison(123, "123"); } TEST_F(DirectStreamableElementsPrintTest, HandlesStrings) { testValueViaComparison(std::string("Hello World"), "Hello World"); @@ -110,9 +106,7 @@ TEST_F(DirectStreamableElementsPrintTest, HandlesBooleans) { testValueViaComparison(false, "false"); } -TEST_F(DirectStreamableElementsPrintTest, HandlesChar) { - testValueViaComparison('a', "a"); -} +TEST_F(DirectStreamableElementsPrintTest, HandlesChar) { testValueViaComparison('a', "a"); } TEST_F(DirectStreamableElementsPrintTest, HandlesCharPointer) { const char *word = "Hello World"; @@ -149,13 +143,11 @@ TEST_F(DirectStreamableElementsPrintTest, handlesStaticArrayOfChars) { // log handles operator<<(&ostream) for std::tuple TEST_F(DirectStreamableElementsPrintTest, HandlesStdTuple) { - const std::tuple tupleOfDifferentTypes = { - 42, "Tuple", 3.14}; + const std::tuple tupleOfDifferentTypes = {42, "Tuple", 3.14}; testValueViaComparison(tupleOfDifferentTypes, "<42, Tuple, 3.14>"); // All empty or zero-initialized elements of tuple - const std::tuple zeroInitializedTuple = {"", 0, - 0.0f}; + const std::tuple zeroInitializedTuple = {"", 0, 0.0f}; testValueViaComparison(zeroInitializedTuple, "<, 0, 0>"); // Nested tuple @@ -197,8 +189,8 @@ TEST_F(ContainerPrintTest, HandlesUnorderedSet) { } TEST_F(ContainerPrintTest, HandlesUnorderedMultimap) { - const std::unordered_multimap unorderedMultimapStringToInt = - {{"one", 1}, {"one", 2}, {"two", 2}}; + const std::unordered_multimap unorderedMultimapStringToInt = { + {"one", 1}, {"one", 2}, {"two", 2}}; std::string pattern = R"(\[\s*)"; // construct regex by adding each permutation pattern += R"((?:\(one, 1\),\s*\(one, 2\),\s*\(two, 2\)|)"; @@ -253,15 +245,14 @@ TEST_F(ContainerPrintTest, HandlesMultiset) { } TEST_F(ContainerPrintTest, HandlesMultimap) { - const std::multimap multimapStringToInt = { - {"one", 1}, {"one", 2}, {"two", 2}}; + const std::multimap multimapStringToInt = {{"one", 1}, {"one", 2}, {"two", 2}}; testValueViaComparison(multimapStringToInt, "[(one, 1), (one, 2), (two, 2)]"); } TEST_F(ContainerPrintTest, HandlesSpan) { std::vector vectorOfInts = {1, 2, 3, 4}; - const std::span spanOnVector( - vectorOfInts.begin(), vectorOfInts.end()); // Create a span from a vector + const std::span spanOnVector(vectorOfInts.begin(), + vectorOfInts.end()); // Create a span from a vector testValueViaComparison(spanOnVector, "[1, 2, 3, 4]"); } @@ -275,8 +266,7 @@ TEST_F(NestedContainerPrintTest, HandlesListOfQueuesOfPoints) { listOfQueues.front().push(Point(1, 1)); listOfQueues.front().push(Point(2, 2)); listOfQueues.front().push(Point(3, 3)); - testValueViaComparison(listOfQueues, - "[[Point(1, 1), Point(2, 2), Point(3, 3)]]"); + testValueViaComparison(listOfQueues, "[[Point(1, 1), Point(2, 2), Point(3, 3)]]"); } TEST_F(NestedContainerPrintTest, HandlesNestedVectors) { @@ -287,24 +277,21 @@ TEST_F(NestedContainerPrintTest, HandlesNestedVectors) { TEST_F(NestedContainerPrintTest, HandlesMapOfVectorOfPoints) { const std::map> mapOfVectors = { {"first", {Point(1, 2)}}, {"second", {Point(3, 4), Point(5, 6)}}}; - testValueViaComparison( - mapOfVectors, - "[(first, [Point(1, 2)]), (second, [Point(3, 4), Point(5, 6)])]"); + testValueViaComparison(mapOfVectors, + "[(first, [Point(1, 2)]), (second, [Point(3, 4), Point(5, 6)])]"); } TEST_F(NestedContainerPrintTest, HandlesVectorOfMaps) { - const std::vector> vectorOfMaps = { - {{"one", 1}, {"two", 2}}, {{"three", 3}, {"four", 4}}}; + const std::vector> vectorOfMaps = {{{"one", 1}, {"two", 2}}, + {{"three", 3}, {"four", 4}}}; // word "three" is lexicographically smaller than "four" - testValueViaComparison(vectorOfMaps, - "[[(one, 1), (two, 2)], [(four, 4), (three, 3)]]"); + testValueViaComparison(vectorOfMaps, "[[(one, 1), (two, 2)], [(four, 4), (three, 3)]]"); } TEST_F(NestedContainerPrintTest, HandlesComplexNestedStructures) { - const std::vector>>> - complexNested = {{{"first", {{1, 2}, {3}}}, {"second", {{4}}}}}; - testValueViaComparison(complexNested, - "[[(first, [[1, 2], [3]]), (second, [[4]])]]"); + const std::vector>>> complexNested = { + {{"first", {{1, 2}, {3}}}, {"second", {{4}}}}}; + testValueViaComparison(complexNested, "[[(first, [[1, 2], [3]]), (second, [[4]])]]"); } TEST_F(EgdeCasesPrintTest, HandleEmptyContainer) { @@ -347,12 +334,9 @@ TEST_F(VariantPrintTest, HandlesVariant) { } TEST_F(ErrorHandlingPrintTest, HandlesErrorCode) { - const auto errorCodeValue = - std::make_error_code(std::errc::function_not_supported).value(); - const std::error_code errorCode = - make_error_code(std::errc::function_not_supported); - testValueViaComparison( - errorCode, "ErrorCode(" + std::to_string(errorCodeValue) + ", generic)"); + const auto errorCodeValue = std::make_error_code(std::errc::function_not_supported).value(); + const std::error_code errorCode = make_error_code(std::errc::function_not_supported); + testValueViaComparison(errorCode, "ErrorCode(" + std::to_string(errorCodeValue) + ", generic)"); } TEST_F(ErrorHandlingPrintTest, HandlesExceptionPtr) { @@ -371,8 +355,7 @@ TEST_F(FileSystemPrintTest, HandlesPath) { TEST_F(FileSystemPrintTest, HandlesDirectoryIterator) { // Setup a temporary directory and files within - std::filesystem::path directory = - std::filesystem::temp_directory_path() / "test_dir"; + std::filesystem::path directory = std::filesystem::temp_directory_path() / "test_dir"; std::filesystem::create_directory(directory); std::ofstream(directory / "file1.txt"); @@ -381,8 +364,7 @@ TEST_F(FileSystemPrintTest, HandlesDirectoryIterator) { std::filesystem::directory_iterator begin(directory); testValueViaRegex( - begin, - R"(Directory\["file1.txt", "file2.txt"\]|Directory\["file2.txt", "file1.txt"\])"); + begin, R"(Directory\["file1.txt", "file2.txt"\]|Directory\["file2.txt", "file1.txt"\])"); // Cleanup std::filesystem::remove_all(directory); @@ -444,8 +426,7 @@ TEST_F(BufferTest, MessageLongerThanLimit) { class LoggingTest : public ::testing::Test { protected: - template - void testLoggingDoesNotChangeContainer(const T &original) { + template void testLoggingDoesNotChangeContainer(const T &original) { const auto copy = original; // Make a copy of the container log(LOG_LEVEL::Info, original); ASSERT_TRUE(check_if_same_content(original, copy)) @@ -488,9 +469,7 @@ TEST_F(LoggingTest, LoggingDoesNotChangeVector) { testLoggingDoesNotChangeContainer(original); } -TEST(LogFunctionTest, LoggingBasic) { - EXPECT_NO_THROW(log(LOG_LEVEL::Debug, "Test123")); -} +TEST(LogFunctionTest, LoggingBasic) { EXPECT_NO_THROW(log(LOG_LEVEL::Debug, "Test123")); } TEST(LogFunctionTest, LoggingWithNonDefaultLogSize) { constexpr std::size_t sizeBiggerThanDefault = 2048; @@ -502,8 +481,7 @@ TEST(LogFunctionTest, LoggingMoreThanOneElement) { constexpr auto testStringLiteral = "Test123"; const auto testVector = std::vector{1, 2, 3, 4}; const auto testPair = std::pair(1, 2.0); - EXPECT_NO_THROW( - log(LOG_LEVEL::Debug, testStringLiteral, testVector, testPair)); + EXPECT_NO_THROW(log(LOG_LEVEL::Debug, testStringLiteral, testVector, testPair)); } TEST(MovingSequencable, MovingSequencableTest) { diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/unit/NumericalTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/unit/NumericalTest.cpp index fa7f8cfddb..48eb5da328 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/unit/NumericalTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/unit/NumericalTest.cpp @@ -8,8 +8,7 @@ namespace rnexecutorch::numerical { // Helper function to check if two float vectors are approximately equal -void expect_vectors_eq(const std::vector &vector1, - const std::vector &vector2, +void expect_vectors_eq(const std::vector &vector1, const std::vector &vector2, float atol = 1.0e-6F) { ASSERT_EQ(vector1.size(), vector2.size()); for (size_t i = 0; i < vector1.size(); i++) { @@ -40,8 +39,7 @@ TEST(NormalizeTests, NormalizeBasic) { std::vector input = {1.0F, 2.0F, 3.0F}; normalize(input); const auto normOfInput = std::sqrtf(14.0F); - const std::vector expected = {1.0F / normOfInput, 2.0F / normOfInput, - 3.0F / normOfInput}; + const std::vector expected = {1.0F / normOfInput, 2.0F / normOfInput, 3.0F / normOfInput}; expect_vectors_eq(input, expected); } @@ -65,8 +63,7 @@ TEST(NormalizeTests, NormalizationOfEmptyVector) { } TEST(MeanPoolingTests, MeanPoolingBasic) { - const std::vector modelOutputVec = {1.0F, 2.0F, 3.0F, - 4.0F, 5.0F, 6.0F}; + const std::vector modelOutputVec = {1.0F, 2.0F, 3.0F, 4.0F, 5.0F, 6.0F}; const std::vector attnMaskVec = {1, 1, 0}; std::span modelOutput(modelOutputVec); @@ -78,8 +75,7 @@ TEST(MeanPoolingTests, MeanPoolingBasic) { } TEST(MeanPoolingTests, MeanPoolingWithZeroAttentionMask) { - const std::vector modelOutputVec = {1.0F, 2.0F, 3.0F, - 4.0F, 5.0F, 6.0F}; + const std::vector modelOutputVec = {1.0F, 2.0F, 3.0F, 4.0F, 5.0F, 6.0F}; const std::vector attnMaskVec = {0, 0, 0}; std::span modelOutput(modelOutputVec); diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/unit/SamplerTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/unit/SamplerTest.cpp index bf7a1d02d6..bd986e9f03 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/unit/SamplerTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/unit/SamplerTest.cpp @@ -15,8 +15,8 @@ using namespace executorch::extension::llm; // Helper: run sampler N times, count how often each index is picked. template -std::vector sampleMany(Sampler &s, std::vector logits, - const std::vector &recent, int n) { +std::vector sampleMany(Sampler &s, std::vector logits, const std::vector &recent, + int n) { std::vector counts(logits.size(), 0); for (int i = 0; i < n; ++i) { std::vector copy = logits; @@ -42,10 +42,8 @@ TEST(SamplerTest, RepetitionPenaltyReducesPositiveLogit) { // 2000) versus the baseline e^-1 / (1 + e^-1) ≈ 0.27 (~538). A static "< 200" // bound would be mathematically unreachable at this penalty. TEST(SamplerTest, RepetitionPenaltyMultipliesNegativeLogit) { - Sampler baseline( - 2, {.temperature = 1.0f, .topp = 1.0f, .repetition_penalty = 1.0f}); - Sampler penalised( - 2, {.temperature = 1.0f, .topp = 1.0f, .repetition_penalty = 1.5f}); + Sampler baseline(2, {.temperature = 1.0f, .topp = 1.0f, .repetition_penalty = 1.0f}); + Sampler penalised(2, {.temperature = 1.0f, .topp = 1.0f, .repetition_penalty = 1.5f}); std::vector logits_b = {0.0f, -1.0f}; std::vector logits_p = {0.0f, -1.0f}; std::vector recent = {1}; @@ -56,10 +54,8 @@ TEST(SamplerTest, RepetitionPenaltyMultipliesNegativeLogit) { // 3. No recent tokens — penalty has no effect. TEST(SamplerTest, RepetitionPenaltyNoRecentTokensHasNoEffect) { - Sampler baseline( - 2, {.temperature = 1.0f, .topp = 1.0f, .repetition_penalty = 1.0f}); - Sampler penalised( - 2, {.temperature = 1.0f, .topp = 1.0f, .repetition_penalty = 2.0f}); + Sampler baseline(2, {.temperature = 1.0f, .topp = 1.0f, .repetition_penalty = 1.0f}); + Sampler penalised(2, {.temperature = 1.0f, .topp = 1.0f, .repetition_penalty = 2.0f}); std::vector logits_b = {1.0f, 1.0f}; std::vector logits_p = {1.0f, 1.0f}; std::vector recent = {}; diff --git a/packages/react-native-executorch/common/rnexecutorch/threads/GlobalThreadPool.h b/packages/react-native-executorch/common/rnexecutorch/threads/GlobalThreadPool.h index 50025eeeb7..cf3936175a 100644 --- a/packages/react-native-executorch/common/rnexecutorch/threads/GlobalThreadPool.h +++ b/packages/react-native-executorch/common/rnexecutorch/threads/GlobalThreadPool.h @@ -30,15 +30,13 @@ class GlobalThreadPool { ThreadConfig config = {}) { std::call_once(initFlag, [&numThreads, config]() { if (!numThreads) { - numThreads = - ::executorch::extension::cpuinfo::get_num_performant_cores(); + numThreads = ::executorch::extension::cpuinfo::get_num_performant_cores(); } numThreads = std::max(numThreads.value(), 2u); - log(rnexecutorch::LOG_LEVEL::Info, "Initializing global thread pool with", - numThreads, "threads"); - instance = std::make_unique(numThreads.value(), - config); + log(rnexecutorch::LOG_LEVEL::Info, "Initializing global thread pool with", numThreads, + "threads"); + instance = std::make_unique(numThreads.value(), config); // Disable OpenCV's internal threading to prevent it from overriding our // thread pool configuration, which would cause degraded performance cv::setNumThreads(0); @@ -46,8 +44,7 @@ class GlobalThreadPool { } // Convenience methods that mirror std::thread interface - template - static auto async(Func &&func, Args &&...args) { + template static auto async(Func &&func, Args &&...args) { return get().submit(std::forward(func), std::forward(args)...); } @@ -58,14 +55,12 @@ class GlobalThreadPool { } // Fire and forget (like std::thread{}.detach()) - template - static void detach(Func &&func, Args &&...args) { + template static void detach(Func &&func, Args &&...args) { get().submitDetached(std::forward(func), std::forward(args)...); } // Execute and wait (like std::thread{}.join()) - template - static auto execute(Func &&func, Args &&...args) { + template static auto execute(Func &&func, Args &&...args) { return get().execute(std::forward(func), std::forward(args)...); } diff --git a/packages/react-native-executorch/common/rnexecutorch/threads/HighPerformanceThreadPool.h b/packages/react-native-executorch/common/rnexecutorch/threads/HighPerformanceThreadPool.h index 67610ec958..dc75714732 100644 --- a/packages/react-native-executorch/common/rnexecutorch/threads/HighPerformanceThreadPool.h +++ b/packages/react-native-executorch/common/rnexecutorch/threads/HighPerformanceThreadPool.h @@ -42,8 +42,7 @@ struct ThreadConfig { class HighPerformanceThreadPool { public: - explicit HighPerformanceThreadPool(size_t numThreads = 1, - ThreadConfig cfg = ThreadConfig()) + explicit HighPerformanceThreadPool(size_t numThreads = 1, ThreadConfig cfg = ThreadConfig()) : config(std::move(cfg)) { #ifdef __ANDROID__ @@ -55,16 +54,14 @@ class HighPerformanceThreadPool { workers.emplace_back(&HighPerformanceThreadPool::workerThread, this, i); } - log(LOG_LEVEL::Debug, "Thread pool initialized with", numThreads, - "workers."); + log(LOG_LEVEL::Debug, "Thread pool initialized with", numThreads, "workers."); } ~HighPerformanceThreadPool() { shutdown(); } // Submit a task and get a future for the result template - auto submit(Func &&func, Args &&...args) - -> std::future { + auto submit(Func &&func, Args &&...args) -> std::future { return submitWithPriority(Priority::NORMAL, std::forward(func), std::forward(args)...); } @@ -77,10 +74,8 @@ class HighPerformanceThreadPool { using ReturnType = decltype(func(args...)); // Create a packaged task - auto boundFunc = - std::bind(std::forward(func), std::forward(args)...); - auto task = std::make_unique>( - std::move(boundFunc)); + auto boundFunc = std::bind(std::forward(func), std::forward(args)...); + auto task = std::make_unique>(std::move(boundFunc)); auto future = task->getFuture(); // Add to queue @@ -92,8 +87,7 @@ class HighPerformanceThreadPool { "Thread pool is shutting down"); } - WorkItem item(std::move(task), priority, - std::chrono::steady_clock::now()); + WorkItem item(std::move(task), priority, std::chrono::steady_clock::now()); taskQueue.push(std::move(item)); } @@ -110,8 +104,7 @@ class HighPerformanceThreadPool { } // Fire and forget task - template - void submitDetached(Func &&func, Args &&...args) { + template void submitDetached(Func &&func, Args &&...args) { submit(std::forward(func), std::forward(args)...); // Future is destroyed, task still runs } @@ -210,8 +203,8 @@ class HighPerformanceThreadPool { const auto numOfCores = std::thread::hardware_concurrency(); for (int32_t i = 0; std::cmp_less(i, numOfCores); ++i) { - std::string path = "/sys/devices/system/cpu/cpu" + std::to_string(i) + - "/cpufreq/cpuinfo_max_freq"; + std::string path = + "/sys/devices/system/cpu/cpu" + std::to_string(i) + "/cpufreq/cpuinfo_max_freq"; std::ifstream file(path); if (!file.good()) { break; @@ -229,13 +222,11 @@ class HighPerformanceThreadPool { } // Sort by frequency - std::ranges::sort(cores, [](const CoreInfo &a, const CoreInfo &b) { - return a.maxFreq > b.maxFreq; - }); + std::ranges::sort(cores, + [](const CoreInfo &a, const CoreInfo &b) { return a.maxFreq > b.maxFreq; }); // Classify cores - const auto numOfPerfCores = - ::executorch::extension::cpuinfo::get_num_performant_cores(); + const auto numOfPerfCores = ::executorch::extension::cpuinfo::get_num_performant_cores(); constexpr float kKiloToGigaRatio = 1e6; for (int32_t i = 0; i < cores.size(); ++i) { @@ -276,8 +267,7 @@ class HighPerformanceThreadPool { setThreadPriority(); - log(LOG_LEVEL::Debug, "Worker", workerIndex, - "configured:", threadName.c_str()); + log(LOG_LEVEL::Debug, "Worker", workerIndex, "configured:", threadName.c_str()); } void setCPUAffinity() { diff --git a/packages/react-native-executorch/common/rnexecutorch/threads/utils/ThreadUtils.h b/packages/react-native-executorch/common/rnexecutorch/threads/utils/ThreadUtils.h index 664480d387..cc9fd4c276 100644 --- a/packages/react-native-executorch/common/rnexecutorch/threads/utils/ThreadUtils.h +++ b/packages/react-native-executorch/common/rnexecutorch/threads/utils/ThreadUtils.h @@ -7,8 +7,7 @@ namespace rnexecutorch::threads::utils { void unsafeSetupThreadPool(uint32_t num_of_cores = 0) { - auto num_of_perf_cores = - ::executorch::extension::cpuinfo::get_num_performant_cores(); + auto num_of_perf_cores = ::executorch::extension::cpuinfo::get_num_performant_cores(); log(LOG_LEVEL::Info, "Detected ", num_of_perf_cores, " performant cores"); // setting num_of_cores to floor(num_of_perf_cores / 2) + 1) because // depending on cpu arch as when possible we want to leave at least 2 @@ -17,13 +16,11 @@ void unsafeSetupThreadPool(uint32_t num_of_cores = 0) { // cores, and for newer ones (like OnePlus 12) resolves to 4, which when // benchmarked gives highest throughput. For iPhones they usually have 2 // performance cores - auto _num_of_cores = num_of_cores - ? num_of_cores - : static_cast(num_of_perf_cores / 2) + 1; + auto _num_of_cores = + num_of_cores ? num_of_cores : static_cast(num_of_perf_cores / 2) + 1; const auto threadpool = ::executorch::extension::threadpool::get_threadpool(); threadpool->_unsafe_reset_threadpool(_num_of_cores); - log(LOG_LEVEL::Info, "Configuring xnnpack for", - threadpool->get_thread_count(), "threads"); + log(LOG_LEVEL::Info, "Configuring xnnpack for", threadpool->get_thread_count(), "threads"); } } // namespace rnexecutorch::threads::utils diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/FrameExtractor.cpp b/packages/react-native-executorch/common/rnexecutorch/utils/FrameExtractor.cpp index d14c522184..476eb0299d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/utils/FrameExtractor.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/utils/FrameExtractor.cpp @@ -30,21 +30,20 @@ cv::Mat extractFromCVPixelBuffer(void *pixelBuffer) { cv::Mat mat; if (pixelFormat == kCVPixelFormatType_32BGRA) { - mat = cv::Mat(static_cast(height), static_cast(width), CV_8UC4, - baseAddress, bytesPerRow); + mat = cv::Mat(static_cast(height), static_cast(width), CV_8UC4, baseAddress, + bytesPerRow); } else if (pixelFormat == kCVPixelFormatType_32RGBA) { - mat = cv::Mat(static_cast(height), static_cast(width), CV_8UC4, - baseAddress, bytesPerRow); + mat = cv::Mat(static_cast(height), static_cast(width), CV_8UC4, baseAddress, + bytesPerRow); } else if (pixelFormat == kCVPixelFormatType_24RGB) { - mat = cv::Mat(static_cast(height), static_cast(width), CV_8UC3, - baseAddress, bytesPerRow); + mat = cv::Mat(static_cast(height), static_cast(width), CV_8UC3, baseAddress, + bytesPerRow); } else { CVPixelBufferUnlockBaseAddress(buffer, kCVPixelBufferLock_ReadOnly); char errorMessage[100]; - std::snprintf(errorMessage, sizeof(errorMessage), - "Unsupported CVPixelBuffer format: %u", pixelFormat); - throw RnExecutorchError(RnExecutorchErrorCode::PlatformNotSupported, - errorMessage); + std::snprintf(errorMessage, sizeof(errorMessage), "Unsupported CVPixelBuffer format: %u", + pixelFormat); + throw RnExecutorchError(RnExecutorchErrorCode::PlatformNotSupported, errorMessage); } CVPixelBufferUnlockBaseAddress(buffer, kCVPixelBufferLock_ReadOnly); @@ -62,12 +61,11 @@ cv::Mat extractFromAHardwareBuffer(void *hardwareBuffer) { AHardwareBuffer_describe(buffer, &desc); void *data = nullptr; - int lockResult = AHardwareBuffer_lock( - buffer, AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN, -1, nullptr, &data); + int lockResult = + AHardwareBuffer_lock(buffer, AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN, -1, nullptr, &data); if (lockResult != 0) { - throw RnExecutorchError(RnExecutorchErrorCode::UnknownError, - "Failed to lock AHardwareBuffer"); + throw RnExecutorchError(RnExecutorchErrorCode::UnknownError, "Failed to lock AHardwareBuffer"); } cv::Mat mat; @@ -81,10 +79,9 @@ cv::Mat extractFromAHardwareBuffer(void *hardwareBuffer) { } else { AHardwareBuffer_unlock(buffer, nullptr); char errorMessage[100]; - std::snprintf(errorMessage, sizeof(errorMessage), - "Unsupported AHardwareBuffer format: %u", desc.format); - throw RnExecutorchError(RnExecutorchErrorCode::PlatformNotSupported, - errorMessage); + std::snprintf(errorMessage, sizeof(errorMessage), "Unsupported AHardwareBuffer format: %u", + desc.format); + throw RnExecutorchError(RnExecutorchErrorCode::PlatformNotSupported, errorMessage); } AHardwareBuffer_unlock(buffer, nullptr); diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.cpp b/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.cpp index 48be09801e..e3ff600a29 100644 --- a/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.cpp @@ -20,8 +20,7 @@ cv::Mat extractFrame(jsi::Runtime &runtime, const jsi::Object &frameData) { } auto nativeBufferValue = frameData.getProperty(runtime, "nativeBuffer"); - uint64_t bufferPtr = static_cast( - nativeBufferValue.asBigInt(runtime).asUint64(runtime)); + uint64_t bufferPtr = static_cast(nativeBufferValue.asBigInt(runtime).asUint64(runtime)); return extractFromNativeBuffer(bufferPtr); } @@ -31,8 +30,7 @@ cv::Mat frameToMat(jsi::Runtime &runtime, const jsi::Value &frameData) { return extractFrame(runtime, frameObj); } -FrameOrientation readFrameOrientation(jsi::Runtime &runtime, - const jsi::Value &frameData) { +FrameOrientation readFrameOrientation(jsi::Runtime &runtime, const jsi::Value &frameData) { auto obj = frameData.asObject(runtime); std::string orientStr = "up"; @@ -61,8 +59,7 @@ cv::Mat pixelsToMat(const JSTensorViewIn &pixelData) { "Invalid pixel data: sizes must have 3 elements " "[height, width, channels], got %zu", pixelData.sizes.size()); - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - errorMessage); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, errorMessage); } int32_t height = pixelData.sizes[0]; @@ -72,16 +69,13 @@ cv::Mat pixelsToMat(const JSTensorViewIn &pixelData) { if (channels != 3) { char errorMessage[100]; std::snprintf(errorMessage, sizeof(errorMessage), - "Invalid pixel data: expected 3 channels (RGB), got %d", - channels); - throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, - errorMessage); + "Invalid pixel data: expected 3 channels (RGB), got %d", channels); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, errorMessage); } if (pixelData.scalarType != executorch::aten::ScalarType::Byte) { - throw RnExecutorchError( - RnExecutorchErrorCode::InvalidUserInput, - "Invalid pixel data: scalarType must be BYTE (Uint8Array)"); + throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput, + "Invalid pixel data: scalarType must be BYTE (Uint8Array)"); } auto *dataPtr = static_cast(pixelData.dataPtr); diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.h b/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.h index b44f040d41..5fbbb58468 100644 --- a/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.h +++ b/packages/react-native-executorch/common/rnexecutorch/utils/FrameProcessor.h @@ -35,8 +35,7 @@ cv::Mat frameToMat(jsi::Runtime &runtime, const jsi::Value &frameData); * Falls back to "up"/false if fields are absent (e.g. when * enablePhysicalBufferRotation is used — transform will be a no-op). */ -FrameOrientation readFrameOrientation(jsi::Runtime &runtime, - const jsi::Value &frameData); +FrameOrientation readFrameOrientation(jsi::Runtime &runtime, const jsi::Value &frameData); /** * @brief Validate a JSTensorViewIn and wrap its data in a RGB cv::Mat. diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/FrameTransform.cpp b/packages/react-native-executorch/common/rnexecutorch/utils/FrameTransform.cpp index e9cf1e9d73..3d1f190209 100644 --- a/packages/react-native-executorch/common/rnexecutorch/utils/FrameTransform.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/utils/FrameTransform.cpp @@ -2,8 +2,7 @@ namespace rnexecutorch::utils { -cv::Mat rotateFrameForModel(const cv::Mat &mat, - const FrameOrientation &orient) { +cv::Mat rotateFrameForModel(const cv::Mat &mat, const FrameOrientation &orient) { if (!orient.isMirrored && orient.orientation == Orientation::Up) { return mat; } @@ -40,8 +39,8 @@ cv::Mat rotateFrameForModel(const cv::Mat &mat, return result; } -void inverseRotateBbox(computer_vision::BBox &bbox, - const FrameOrientation &orient, cv::Size rotatedSize) { +void inverseRotateBbox(computer_vision::BBox &bbox, const FrameOrientation &orient, + cv::Size rotatedSize) { const float w = static_cast(rotatedSize.width); const float h = static_cast(rotatedSize.height); @@ -83,8 +82,8 @@ void inverseRotateBbox(computer_vision::BBox &bbox, if (orient.isMirrored) { // After CW/CCW rotation (Up/Down) screen dims are swapped: rH × rW. // After no-op/180° (Left/Right) screen dims are unchanged: rW × rH. - bool swapped = (orient.orientation == Orientation::Up || - orient.orientation == Orientation::Down); + bool swapped = + (orient.orientation == Orientation::Up || orient.orientation == Orientation::Down); float sw = swapped ? h : w; float sh = swapped ? w : h; float nx1 = sw - bbox.p2.x, ny1 = sh - bbox.p2.y; diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/FrameTransform.h b/packages/react-native-executorch/common/rnexecutorch/utils/FrameTransform.h index 8f9ca46cc2..f733b8baa1 100644 --- a/packages/react-native-executorch/common/rnexecutorch/utils/FrameTransform.h +++ b/packages/react-native-executorch/common/rnexecutorch/utils/FrameTransform.h @@ -49,8 +49,8 @@ cv::Mat rotateFrameForModel(const cv::Mat &mat, const FrameOrientation &orient); * Inverse of rotateFrameForModel for coordinates. * rotatedSize is the rotated frame size (rotated.size()). */ -void inverseRotateBbox(computer_vision::BBox &bbox, - const FrameOrientation &orient, cv::Size rotatedSize); +void inverseRotateBbox(computer_vision::BBox &bbox, const FrameOrientation &orient, + cv::Size rotatedSize); /** * @brief Rotate a cv::Mat from rotated-frame space back to screen space. @@ -82,8 +82,7 @@ concept Point2D = requires(P &p) { */ template requires Point2D -void inverseRotatePoints(Points &points, const FrameOrientation &orient, - cv::Size rotatedSize) { +void inverseRotatePoints(Points &points, const FrameOrientation &orient, cv::Size rotatedSize) { const float w = static_cast(rotatedSize.width); const float h = static_cast(rotatedSize.height); @@ -120,8 +119,8 @@ void inverseRotatePoints(Points &points, const FrameOrientation &orient, #if defined(__APPLE__) if (orient.isMirrored) { - bool swapped = (orient.orientation == Orientation::Up || - orient.orientation == Orientation::Down); + bool swapped = + (orient.orientation == Orientation::Up || orient.orientation == Orientation::Down); float sw = swapped ? h : w; float sh = swapped ? w : h; for (auto &p : points) { diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Processing.h b/packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Processing.h index 3bd3022d4a..6ef4ef5348 100644 --- a/packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Processing.h +++ b/packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Processing.h @@ -14,8 +14,7 @@ std::vector nonMaxSuppression(std::vector items, double iouThreshold) { return {}; } - std::ranges::sort(items, - [](const T &a, const T &b) { return a.score > b.score; }); + std::ranges::sort(items, [](const T &a, const T &b) { return a.score > b.score; }); std::vector result; std::vector suppressed(items.size(), false); diff --git a/packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Types.h b/packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Types.h index 7698d9807f..2aa1274963 100644 --- a/packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Types.h +++ b/packages/react-native-executorch/common/rnexecutorch/utils/computer_vision/Types.h @@ -14,13 +14,10 @@ struct BBox { float height() const { return p2.y - p1.y; } float area() const { return width() * height(); } - bool isValid() const { - return p2.x > p1.x && p2.y > p1.y && p1.x >= 0.0f && p1.y >= 0.0f; - } + bool isValid() const { return p2.x > p1.x && p2.y > p1.y && p1.x >= 0.0f && p1.y >= 0.0f; } BBox scale(float widthRatio, float heightRatio) const { - return {{p1.x * widthRatio, p1.y * heightRatio}, - {p2.x * widthRatio, p2.y * heightRatio}}; + return {{p1.x * widthRatio, p1.y * heightRatio}, {p2.x * widthRatio, p2.y * heightRatio}}; } Point p1, p2; diff --git a/packages/react-native-executorch/ios/RnExecutorch/ETInstaller.h b/packages/react-native-executorch/ios/RnExecutorch/ETInstaller.h index d01236e994..e7cc335e21 100644 --- a/packages/react-native-executorch/ios/RnExecutorch/ETInstaller.h +++ b/packages/react-native-executorch/ios/RnExecutorch/ETInstaller.h @@ -2,7 +2,6 @@ #import #import -@interface ETInstaller - : RCTEventEmitter +@interface ETInstaller : RCTEventEmitter @end diff --git a/packages/react-native-executorch/ios/RnExecutorch/ETInstaller.mm b/packages/react-native-executorch/ios/RnExecutorch/ETInstaller.mm index 2a8bc519ba..381cd17de3 100644 --- a/packages/react-native-executorch/ios/RnExecutorch/ETInstaller.mm +++ b/packages/react-native-executorch/ios/RnExecutorch/ETInstaller.mm @@ -21,21 +21,18 @@ @implementation ETInstaller RCT_EXPORT_MODULE(ETInstaller); RCT_EXPORT_BLOCKING_SYNCHRONOUS_METHOD(install) { - auto jsiRuntime = - reinterpret_cast(self.bridge.runtime); + auto jsiRuntime = reinterpret_cast(self.bridge.runtime); auto jsCallInvoker = _callInvoker.callInvoker; assert(jsiRuntime != nullptr); auto fetchUrl = [](std::string url) { @try { - NSString *nsUrlStr = - [NSString stringWithCString:url.c_str() - encoding:[NSString defaultCStringEncoding]]; + NSString *nsUrlStr = [NSString stringWithCString:url.c_str() + encoding:[NSString defaultCStringEncoding]]; NSURL *nsUrl = [NSURL URLWithString:nsUrlStr]; NSData *data = [NSData dataWithContentsOfURL:nsUrl]; - const std::byte *bytePtr = - reinterpret_cast(data.bytes); + const std::byte *bytePtr = reinterpret_cast(data.bytes); int bufferLength = [data length]; return std::vector(bytePtr, bytePtr + bufferLength); } @catch (NSException *exception) { @@ -43,8 +40,8 @@ @implementation ETInstaller } }; bool isEmulator = TARGET_OS_SIMULATOR; - rnexecutorch::RnExecutorchInstaller::injectJSIBindings( - jsiRuntime, jsCallInvoker, fetchUrl, isEmulator); + rnexecutorch::RnExecutorchInstaller::injectJSIBindings(jsiRuntime, jsCallInvoker, fetchUrl, + isEmulator); NSLog(@"Successfully installed JSI bindings for react-native-executorch!"); return @true;