Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions packages/react-native-executorch/cpp/extensions/cv/image_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -614,4 +614,90 @@ void install_normalize(jsi::Runtime &rt, jsi::Object &module) {

module.setProperty(rt, name, jsi::Function::createFromHostFunction(rt, jsi::PropNameID::forAscii(rt, name), 3, fnBody));
}

void install_applyColormap(jsi::Runtime &rt, jsi::Object &module) {
auto name = "applyColormap";
auto fnBody = [](jsi::Runtime &rt, const jsi::Value &thisVal, const jsi::Value *args, size_t count) -> jsi::Value {
if (count != 3) {
throw jsi::JSError(rt, "Usage: applyColormap(src, dst, colormap)");
}

if (!args[0].isObject() || !args[0].asObject(rt).isHostObject<TensorHostObject>(rt)) {
throw jsi::JSError(rt, "applyColormap: src must be a Tensor");
}
if (!args[1].isObject() || !args[1].asObject(rt).isHostObject<TensorHostObject>(rt)) {
throw jsi::JSError(rt, "applyColormap: dst must be a Tensor");
}

auto src = args[0].asObject(rt).getHostObject<TensorHostObject>(rt);
auto dst = args[1].asObject(rt).getHostObject<TensorHostObject>(rt);

if (src->dtype_ != rnexecutorch::core::types::DType::int32) {
throw jsi::JSError(rt, "applyColormap: src must be int32");
}
if (dst->dtype_ != rnexecutorch::core::types::DType::uint8) {
throw jsi::JSError(rt, "applyColormap: dst must be uint8");
}
if (dst->numel_ != src->numel_ * 4) {
throw jsi::JSError(rt, "applyColormap: dst must have exactly 4 times the number of elements as src (RGBA channels)");
}

if (!args[2].isObject() || !args[2].asObject(rt).isArray(rt)) {
throw jsi::JSError(rt, "applyColormap: colormap must be an array");
}

auto colormapArray = args[2].asObject(rt).asArray(rt);
size_t numColors = colormapArray.size(rt);
std::vector<std::array<uint8_t, 4>> lut(numColors);
for (size_t i = 0; i < numColors; ++i) {
auto colorVal = colormapArray.getValueAtIndex(rt, i);
if (!colorVal.isObject() || !colorVal.asObject(rt).isArray(rt)) {
throw jsi::JSError(rt, "applyColormap: colormap entry must be an array");
}
auto color = colorVal.asObject(rt).asArray(rt);
if (color.size(rt) != 4) {
throw jsi::JSError(rt, "applyColormap: colormap entry must be an RGBA color array of size 4");
}
for (size_t c = 0; c < 4; ++c) {
auto channelVal = color.getValueAtIndex(rt, c);
if (!channelVal.isNumber()) {
throw jsi::JSError(rt, "applyColormap: colormap channel value must be a number");
}
lut[i][c] = static_cast<uint8_t>(channelVal.asNumber());
}
}

std::shared_lock<std::shared_mutex> srcLock(src->mutex_, std::try_to_lock);
std::unique_lock<std::shared_mutex> dstLock(dst->mutex_, std::try_to_lock);
if (!srcLock.owns_lock() || !dstLock.owns_lock()) {
throw jsi::JSError(rt, "applyColormap: tensors in use");
}

if (!src->data_ || !dst->data_) {
throw jsi::JSError(rt, "applyColormap: tensor has been disposed");
}

size_t pixels = src->numel_;

const int32_t *srcData = reinterpret_cast<const int32_t *>(src->data_.get());
uint8_t *dstData = dst->data_.get();

for (size_t i = 0; i < pixels; ++i) {
int32_t idx = srcData[i];
if (idx < 0 || static_cast<size_t>(idx) >= numColors) {
throw jsi::JSError(rt, "applyColormap: tensor contains class index (" +
std::to_string(idx) + ") that exceeds provided colormap size (" +
std::to_string(numColors) + ")");
}

dstData[i * 4 + 0] = lut[idx][0];
dstData[i * 4 + 1] = lut[idx][1];
dstData[i * 4 + 2] = lut[idx][2];
dstData[i * 4 + 3] = lut[idx][3];
}

return jsi::Value(rt, args[1]);
};
module.setProperty(rt, name, jsi::Function::createFromHostFunction(rt, jsi::PropNameID::forAscii(rt, name), 3, fnBody));
}
} // namespace rnexecutorch::extensions::cv::image_ops
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ void install_cvtColor(facebook::jsi::Runtime &rt, facebook::jsi::Object &module)
void install_toChannelsFirst(facebook::jsi::Runtime &rt, facebook::jsi::Object &module);
void install_toChannelsLast(facebook::jsi::Runtime &rt, facebook::jsi::Object &module);
void install_normalize(facebook::jsi::Runtime &rt, facebook::jsi::Object &module);
void install_applyColormap(facebook::jsi::Runtime &rt, facebook::jsi::Object &module);
} // namespace rnexecutorch::extensions::cv::image_ops
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ void install(facebook::jsi::Runtime &rt, facebook::jsi::Object &module) {
image_ops::install_toChannelsFirst(rt, cvModule);
image_ops::install_toChannelsLast(rt, cvModule);
image_ops::install_normalize(rt, cvModule);
image_ops::install_applyColormap(rt, cvModule);

module.setProperty(rt, "cv", cvModule);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,26 +277,19 @@ void install_argmax(jsi::Runtime &rt, jsi::Object &module) {
}

int32_t *dstData = reinterpret_cast<int32_t *>(dst->data_.get());
std::vector<float> maxVals(inner);

for (size_t o = 0; o < outer; ++o) {
const float *srcSlab = srcData + o * axisDim * inner;
int32_t *dstRow = dstData + o * inner;

for (size_t i = 0; i < inner; ++i) {
maxVals[i] = -std::numeric_limits<float>::infinity();
dstRow[i] = 0;
}

for (size_t d = 0; d < axisDim; ++d) {
const float *srcRow = srcSlab + d * inner;
for (size_t i = 0; i < inner; ++i) {
const float val = srcRow[i];
if (val > maxVals[i]) {
maxVals[i] = val;
dstRow[i] = static_cast<int32_t>(d);
float maxVal = -std::numeric_limits<float>::infinity();
int32_t maxIdx = 0;
for (size_t d = 0; d < axisDim; ++d) {
const float val = srcData[o * axisDim * inner + d * inner + i];
if (val > maxVal) {
maxVal = val;
maxIdx = static_cast<int32_t>(d);
}
}
dstData[o * inner + i] = maxIdx;
}
}

Expand Down
34 changes: 34 additions & 0 deletions packages/react-native-executorch/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1005,8 +1005,42 @@ export const IMAGENET1K_LABELS = [
'toilet tissue, toilet paper, bathroom tissue',
] as const;

/**
* Pascal VOC dataset label array containing the 21 categories.
* @category Constants
*/
export const PASCAL_VOC_LABELS = [
'background',
'aeroplane',
'bicycle',
'bird',
'boat',
'bottle',
'bus',
'car',
'cat',
'chair',
'cow',
'diningtable',
'dog',
'horse',
'motorbike',
'person',
'pottedplant',
'sheep',
'sofa',
'train',
'tvmonitor',
] as const;

/**
* Type representing a valid ImageNet 1K label string.
* @category Types
*/
export type ImageNet1KLabel = (typeof IMAGENET1K_LABELS)[number];

/**
* Type representing a valid Pascal VOC label string.
* @category Types
*/
export type PascalVocLabel = (typeof PASCAL_VOC_LABELS)[number];
26 changes: 26 additions & 0 deletions packages/react-native-executorch/src/extensions/cv/ops/image.ts
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,29 @@ export function normalize(src: Tensor, dst: Tensor, opts?: NormalizeOptions): Te
} as const;
return rnexecutorchJsi.cv.normalize(src, dst, { ...defaultNormalizeOptions, ...opts });
}

/**
* Applies a colormap to a single-channel image tensor, mapping class indices to
* RGBA colors.
*
* This operation iterates over each index/class ID in the source tensor, looks
* up its corresponding RGBA color in the provided colormap palette, and writes
* it to the destination tensor.
* @category Typescript API
* @param src The source index/mask tensor. Must be an integer tensor of `int32`
* dtype containing class indices. Shape `[H, W, 1]` (or `[H, W]`).
* @param dst The pre-allocated destination tensor to write the mapped RGBA
* values to. Must be a 3D image tensor in HWC layout and `uint8` dtype. Shape
* `[H, W, 4]`.
* @param colormap An array of RGBA color arrays `[R, G, B, A]` corresponding to each
* class index. The size of this list must cover all class indices present in `src`.
* @returns The destination tensor with the applied colormap.
*/
export function applyColormap(
src: Tensor,
dst: Tensor,
colormap: [number, number, number, number][]
): Tensor {
'worklet';
return rnexecutorchJsi.cv.applyColormap(src, dst, colormap);
}
Loading