diff --git a/.cspell-wordlist.txt b/.cspell-wordlist.txt index fd8fb360b5..fef84a8a24 100644 --- a/.cspell-wordlist.txt +++ b/.cspell-wordlist.txt @@ -229,3 +229,4 @@ imgproc c10 probas Probas +Skia diff --git a/apps/computer-vision/app/_layout.tsx b/apps/computer-vision/app/_layout.tsx index f5506bd40c..91b231a771 100644 --- a/apps/computer-vision/app/_layout.tsx +++ b/apps/computer-vision/app/_layout.tsx @@ -1,5 +1,5 @@ import { Drawer } from 'expo-router/drawer'; -import ColorPalette from '../colors'; +import { ColorPalette } from '../theme'; import React from 'react'; export default function Layout() { diff --git a/apps/computer-vision/app/classification/index.tsx b/apps/computer-vision/app/classification/index.tsx index f02b93d740..0849f9853b 100644 --- a/apps/computer-vision/app/classification/index.tsx +++ b/apps/computer-vision/app/classification/index.tsx @@ -1,20 +1,15 @@ import React, { useState } from 'react'; -import { - View, - Text, - StyleSheet, - Image, - TouchableOpacity, - ScrollView, - ActivityIndicator, - Platform, -} from 'react-native'; -import { Skia } from '@shopify/react-native-skia'; +import { View, Text, StyleSheet, ScrollView, Platform } from 'react-native'; +import { commonStyles, ColorPalette } from '../../theme'; +import { useImage } from '@shopify/react-native-skia'; import { useClassifier, models } from 'react-native-executorch'; -import ScreenWrapper from '../../ScreenWrapper'; -import ColorPalette from '../../colors'; +import ScreenWrapper from '../../components/ScreenWrapper'; import { getImage } from '../../utils'; import { ModelPicker, type ModelOption } from '../../components/ModelPicker'; +import { ImageViewport } from '../../components/ImageViewport'; +import { ModelStatus } from '../../components/ModelStatus'; +import { LatencyIndicator } from '../../components/LatencyIndicator'; +import { Button } from '../../components/Button'; const MODEL_OPTIONS: ModelOption[] = [ { @@ -32,29 +27,16 @@ const MODEL_OPTIONS: ModelOption[] = [ }, ]; -async function loadImageBuffer(uri: string) { - const data = await Skia.Data.fromURI(uri); - const img = Skia.Image.MakeImageFromEncoded(data); - if (!img) { - throw new Error('Failed to decode image using Skia'); - } - return { - data: img.readPixels() as Uint8Array, - width: img.width(), - height: img.height(), - format: 'rgba' as const, - layout: 'hwc' as const, - }; -} - -export default function ClassificationScreen() { +function ClassificationContent() { const [selectedModel, setSelectedModel] = useState(MODEL_OPTIONS[0].value); const [imageUri, setImageUri] = useState(null); - const [loading, setLoading] = useState(false); + const [isProcessing, setIsProcessing] = useState(false); const [results, setResults] = useState<{ label: string; confidence: number }[]>([]); const [latency, setLatency] = useState(null); const [error, setError] = useState(null); + const skiaImage = useImage(imageUri, (err) => setError(err.message || String(err))); + const { isReady, downloadProgress, @@ -64,228 +46,132 @@ export default function ClassificationScreen() { } = useClassifier(selectedModel); const handlePickImage = async (useCamera: boolean) => { - const asset = await getImage(useCamera); - if (asset?.uri) { - setImageUri(asset.uri); - setResults([]); - setLatency(null); - setError(null); + setError(null); + try { + const uri = await getImage(useCamera); + if (uri) { + setImageUri(uri); + setResults([]); + setLatency(null); + } + } catch (e: any) { + setError(e.message || String(e)); } }; const runClassification = async (sync: boolean) => { - if (!imageUri || !classify || !classifyWorklet) return; - if (!sync) setLoading(true); + if (!skiaImage || !classify || !classifyWorklet) return; + if (!sync) setIsProcessing(true); setError(null); try { - const inputBuffer = await loadImageBuffer(imageUri); + const pixels = skiaImage.readPixels(); + if (!pixels) { + throw new Error('Failed to read pixels from image'); + } + if (!(pixels instanceof Uint8Array)) { + throw new Error('Expected Uint8Array from readPixels'); + } + const buffer = { + data: pixels, + width: skiaImage.width(), + height: skiaImage.height(), + format: 'rgba' as const, + layout: 'hwc' as const, + }; const start = Date.now(); const output = sync - ? classifyWorklet(inputBuffer, { topk: 5 }) - : await classify(inputBuffer, { topk: 5 }); + ? classifyWorklet(buffer, { topk: 5 }) + : await classify(buffer, { topk: 5 }); + setLatency(Date.now() - start); setResults(output); } catch (e: any) { setError(e.message || String(e)); } finally { - if (!sync) setLoading(false); + if (!sync) setIsProcessing(false); } }; const activeError = loadError ? String(loadError) : error; return ( - - - Image Classification - - { - setSelectedModel(model); - setResults([]); - setLatency(null); - setError(null); - }} + + + Upload or capture an image to identify objects using a classifier. + + + { + setSelectedModel(model); + setResults([]); + setLatency(null); + setError(null); + }} + /> + + + + handlePickImage(false)} /> + + +