diff --git a/apps/text-embeddings/app/text-embeddings/index.tsx b/apps/text-embeddings/app/text-embeddings/index.tsx index 88e39ce063..2c62a22922 100644 --- a/apps/text-embeddings/app/text-embeddings/index.tsx +++ b/apps/text-embeddings/app/text-embeddings/index.tsx @@ -5,7 +5,6 @@ import { TextInput, TouchableOpacity, View, - SafeAreaView, ScrollView, KeyboardAvoidingView, Platform, @@ -16,10 +15,17 @@ import { models, useTextEmbeddings, TextEmbeddingsProps, + EmbeddingResult, } from 'react-native-executorch'; +import { useIsFocused } from 'expo-router'; +import { dotProduct, maxSim } from '../../utils/math'; +import ErrorBanner from '../../components/ErrorBanner'; +import { SafeAreaView } from 'react-native-safe-area-context'; + const textEmbedding = models.text_embedding; type TextEmbeddingModel = TextEmbeddingsProps['model']; +type Encoding = Float32Array | EmbeddingResult; const MODELS: { label: string; value: TextEmbeddingModel }[] = [ { label: 'MiniLM L6', value: textEmbedding.all_minilm_l6_v2() }, @@ -43,10 +49,42 @@ const MODELS: { label: string; value: TextEmbeddingModel }[] = [ label: 'Multilingual Paraphrase', value: textEmbedding.paraphrase_multilingual_minilm_l12_v2(), }, + { + label: 'LFM2.5 Embedding XNNPACK', + value: textEmbedding.lfm2_5_embedding_350m({ backend: 'xnnpack' }), + }, + { + label: 'LFM2.5 Embedding MLX', + value: textEmbedding.lfm2_5_embedding_350m({ backend: 'mlx' }), + }, + { + label: 'LFM2.5 ColBERT (late-interaction)', + value: textEmbedding.lfm2_5_colbert_350m(), + }, +]; + +const CORPUS: string[] = [ + 'The forecast says heavy showers this afternoon.', + "It's so sunny outside today!", + 'A thick fog rolled in over the harbor at dawn.', + 'The home team scored in the final minute to win the match.', + 'She sprinted the last lap and broke the national record.', + 'Fans packed the stadium for the championship game.', + 'Simmer the tomatoes with garlic before adding the pasta.', + 'He whisked the eggs and folded in the melted chocolate.', + 'The new phone has a faster chip and a brighter screen.', + 'Our servers crashed under the sudden spike in traffic.', + 'The flight to Tokyo was delayed by three hours.', + 'We hiked along the coast and camped near the cliffs.', +]; + +const EXAMPLE_QUERIES: string[] = [ + "What's the weather like?", + 'Who won the match?', + 'Tell me about the latest technology', + 'How do I cook dinner?', + 'Where did they travel?', ]; -import { useIsFocused } from 'expo-router'; -import { dotProduct } from '../../utils/math'; -import ErrorBanner from '../../components/ErrorBanner'; export default function TextEmbeddingsScreenWrapper() { const isFocused = useIsFocused(); @@ -54,6 +92,8 @@ export default function TextEmbeddingsScreenWrapper() { return isFocused ? : null; } +type RankedResult = { sentence: string; similarity: number }; + function TextEmbeddingsScreen() { const [selectedModel, setSelectedModel] = useState( textEmbedding.all_minilm_l6_v2() @@ -61,88 +101,70 @@ function TextEmbeddingsScreen() { const model = useTextEmbeddings({ model: selectedModel }); const [error, setError] = useState(null); - const [inputSentence, setInputSentence] = useState(''); - const [sentencesWithEmbeddings, setSentencesWithEmbeddings] = useState< - { sentence: string; embedding: Float32Array }[] - >([]); - const [topMatches, setTopMatches] = useState< - { sentence: string; similarity: number }[] + const isMultiVector = !!selectedModel.multiVector; + const skipListIds = selectedModel.skipListIds ?? []; + + const [query, setQuery] = useState(''); + const [corpusEmbeddings, setCorpusEmbeddings] = useState< + { sentence: string; embedding: Encoding }[] >([]); + const [results, setResults] = useState([]); const [embeddingTime, setEmbeddingTime] = useState(null); + const [indexing, setIndexing] = useState(false); useEffect( () => { - const computeEmbeddings = async () => { + let cancelled = false; + const indexCorpus = async () => { if (!model.isReady) return; - - const sentences = [ - 'The weather is lovely today.', - "It's so sunny outside!", - 'He drove to the stadium.', - ]; - + setIndexing(true); + setResults([]); try { - const embeddings = []; - for (const sentence of sentences) { - const embedding = await model.forward(sentence); - embeddings.push({ sentence, embedding }); + const embedded = []; + for (const sentence of CORPUS) { + const embedding = await model.forward(sentence, 'document'); + if (cancelled) return; + embedded.push({ sentence, embedding }); } - - setSentencesWithEmbeddings(embeddings); - } catch (e) { - setError(e instanceof Error ? e.message : String(e)); + setCorpusEmbeddings(embedded); + } finally { + if (!cancelled) setIndexing(false); } }; - - computeEmbeddings(); + indexCorpus(); + return () => { + cancelled = true; + }; }, + // eslint-disable-next-line react-hooks/exhaustive-deps - [model.isReady] + [model.isReady, selectedModel] ); - const checkSimilarities = async () => { - if (!model.isReady || !inputSentence.trim()) return; - + const runSearch = async (queryText: string = query) => { + const q = queryText.trim(); + if (!model.isReady || !q || corpusEmbeddings.length === 0) return; + setQuery(queryText); try { const start = Date.now(); - const inputEmbedding = await model.forward(inputSentence); + const queryEmbedding = (await model.forward(q, 'query')) as Encoding; setEmbeddingTime(Date.now() - start); - const matches = sentencesWithEmbeddings.map( - ({ sentence, embedding }) => ({ + const ranked = corpusEmbeddings + .map(({ sentence, embedding }) => ({ sentence, - similarity: dotProduct(inputEmbedding, embedding), - }) - ); - matches.sort((a, b) => b.similarity - a.similarity); - setTopMatches(matches.slice(0, 3)); - } catch (e) { - setError(e instanceof Error ? e.message : String(e)); - } - }; - - const addToSentences = async () => { - if (!model.isReady || !inputSentence.trim()) return; - - try { - const start = Date.now(); - const embedding = await model.forward(inputSentence); - setEmbeddingTime(Date.now() - start); - setSentencesWithEmbeddings((prev) => [ - ...prev, - { sentence: inputSentence, embedding }, - ]); - } catch (e) { - setError(e instanceof Error ? e.message : String(e)); - } - - setInputSentence(''); - setTopMatches([]); - }; - - const clearList = async () => { - if (!model.isReady) return; - try { - setSentencesWithEmbeddings([]); + similarity: isMultiVector + ? maxSim( + queryEmbedding as EmbeddingResult, + embedding as EmbeddingResult, + skipListIds + ) + : dotProduct( + queryEmbedding as Float32Array, + embedding as Float32Array + ), + })) + .sort((a, b) => b.similarity - a.similarity); + setResults(ranked); } catch (e) { setError(e instanceof Error ? e.message : String(e)); } @@ -158,6 +180,9 @@ function TextEmbeddingsScreen() { return model.isGenerating ? 'Generating...' : 'Model is ready'; }; + const ready = model.isReady && !indexing && corpusEmbeddings.length > 0; + const canSearch = ready && !!query.trim(); + return ( - Text Embeddings Playground + Semantic Search {getModelStatusText()} { setSelectedModel(m); - setSentencesWithEmbeddings([]); - setTopMatches([]); + setCorpusEmbeddings([]); + setResults([]); + setQuery(''); }} /> setError(null)} /> - List of Existing Sentences - {sentencesWithEmbeddings.map((item, index) => ( - - - {item.sentence} - - ))} - - - Try Your Sentence + + Search the corpus ({CORPUS.length} sentences) + + + {isMultiVector + ? 'Ranks per-token vectors with MaxSim (late interaction). Ask a full question — tap an example or type your own.' + : 'Ranks every sentence by meaning. Ask a full question — tap an example or type your own.'} + + + {EXAMPLE_QUERIES.map((q) => ( + runSearch(q)} + > + {q} + + ))} + runSearch()} + returnKeyType="search" /> - - runSearch()} + style={[ + styles.buttonPrimary, + !canSearch && styles.buttonDisabled, + ]} + disabled={!canSearch} + > + + - - - Find Similar - - - - - - - Add to List - - - - - - Clear List - - - - + {indexing ? 'Indexing corpus…' : 'Search'} + + {embeddingTime !== null && ( - Embedding time: {embeddingTime} ms + Query embedded in {embeddingTime} ms )} - {topMatches.length > 0 && ( - - Top Matches - {topMatches.map((item, index) => ( - - {item.sentence} ({item.similarity.toFixed(2)}) - - ))} - - )} + + {results.length > 0 && ( + + Results + {results.map((item, index) => ( + + ))} + + )} ); } +function ResultRow({ + sentence, + similarity, + best, + rank, +}: { + sentence: string; + similarity: number; + best: number; + rank: number; +}) { + const fraction = best > 0 ? Math.max(0, similarity / best) : 0; + return ( + + + {sentence} + {similarity.toFixed(2)} + + + + + + ); +} + const styles = StyleSheet.create({ container: { flex: 1, @@ -323,11 +345,68 @@ const styles = StyleSheet.create({ marginBottom: 12, color: '#1E293B', }, - sentenceText: { - fontSize: 14, + hint: { + fontSize: 13, + color: '#64748B', + marginBottom: 12, + lineHeight: 18, + }, + chipRow: { + flexDirection: 'row', + flexWrap: 'wrap', + gap: 8, + marginBottom: 12, + }, + chip: { + backgroundColor: '#EEF2FF', + borderColor: '#C7D2FE', + borderWidth: 1, + borderRadius: 16, + paddingHorizontal: 12, + paddingVertical: 6, + }, + chipDisabled: { + opacity: 0.4, + }, + chipText: { + fontSize: 13, + color: 'navy', + }, + resultRow: { + marginBottom: 14, + }, + resultHeader: { + flexDirection: 'row', + justifyContent: 'space-between', + alignItems: 'flex-start', marginBottom: 6, + gap: 8, + }, + resultText: { + flex: 1, + fontSize: 14, color: '#334155', }, + resultScore: { + fontSize: 14, + fontWeight: '600', + color: '#0F172A', + fontVariant: ['tabular-nums'], + }, + barTrack: { + height: 8, + borderRadius: 4, + backgroundColor: '#E2E8F0', + overflow: 'hidden', + }, + barFill: { + height: '100%', + borderRadius: 4, + backgroundColor: '#94A3B8', + }, + barFillTop: { + backgroundColor: 'navy', + }, input: { backgroundColor: '#F1F5F9', borderRadius: 10, @@ -338,17 +417,8 @@ const styles = StyleSheet.create({ minHeight: 40, textAlignVertical: 'top', }, - buttonContainer: { - width: '100%', - gap: 10, - }, - buttonGroup: { - flexDirection: 'row', - justifyContent: 'space-between', - gap: 10, - }, buttonPrimary: { - flex: 1, + width: '100%', backgroundColor: 'navy', padding: 12, borderRadius: 10, @@ -356,17 +426,6 @@ const styles = StyleSheet.create({ alignItems: 'center', justifyContent: 'center', }, - buttonSecondary: { - flex: 1, - backgroundColor: 'transparent', - borderWidth: 2, - borderColor: 'navy', - padding: 12, - borderRadius: 10, - flexDirection: 'row', - alignItems: 'center', - justifyContent: 'center', - }, buttonDisabled: { backgroundColor: '#f0f0f0', borderColor: '#d3d3d3', @@ -376,17 +435,9 @@ const styles = StyleSheet.create({ textAlign: 'center', fontWeight: '500', }, - buttonTextOutline: { - color: 'navy', - textAlign: 'center', - fontWeight: '500', - }, buttonTextDisabled: { color: 'gray', }, - topMatchesContainer: { - marginTop: 20, - }, statsText: { fontSize: 13, color: '#64748B', diff --git a/apps/text-embeddings/utils/math.ts b/apps/text-embeddings/utils/math.ts index 50c70d1f92..44248e1658 100644 --- a/apps/text-embeddings/utils/math.ts +++ b/apps/text-embeddings/utils/math.ts @@ -1,6 +1,7 @@ import { RnExecutorchError, RnExecutorchErrorCode, + EmbeddingResult, } from 'react-native-executorch'; export const dotProduct = (a: Float32Array, b: Float32Array) => { @@ -17,3 +18,28 @@ export const dotProduct = (a: Float32Array, b: Float32Array) => { } return sum; }; + +export const maxSim = ( + query: EmbeddingResult, + doc: EmbeddingResult, + skipListIds: number[] = [] +) => { + const dim = query.embeddingDim; + const skip = new Set(skipListIds); + let score = 0; + for (let qi = 0; qi < query.numTokens; qi++) { + const qOff = qi * dim; + let best = -Infinity; + for (let di = 0; di < doc.numTokens; di++) { + if (skip.has(doc.tokenIds[di]!)) continue; + const dOff = di * dim; + let dot = 0; + for (let k = 0; k < dim; k++) { + dot += (query.vectors[qOff + k] ?? 0) * (doc.vectors[dOff + k] ?? 0); + } + if (dot > best) best = dot; + } + if (best !== -Infinity) score += best; + } + return score; +}; diff --git a/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.cpp b/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.cpp index 76e0fb90c7..dfd9243c48 100644 --- a/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.cpp @@ -26,17 +26,15 @@ TokenizerModule::TokenizerModule( memorySizeLowerBound = std::filesystem::file_size(modelPath); } -std::vector TokenizerModule::encode(std::string s) const { +// When the tokenizer.json defines a post_processor, the underlying HFTokenizer +// treats non-zero bos/eos as a flag to run it with add_special_token=true (not +// a literal count). So bos=eos=0 skips special tokens; bos=eos=1 applies them. +std::vector TokenizerModule::encodeImpl(const std::string &s, + int8_t bos, int8_t eos) const { if (!tokenizer) { THROW_NOT_LOADED_ERROR(); } - - // If the used tokenizer.json has defined post_processor field, - // setting any of bos or eos arguments to value other than provided constant - // ( which is 0) will result in running the post_processor with - // 'add_special_token' flag - auto encodeResult = - tokenizer->encode(s, numOfAddedBoSTokens, numOfAddedEoSTokens); + auto encodeResult = tokenizer->encode(s, bos, eos); if (!encodeResult.ok()) { throw RnExecutorchError( RnExecutorchErrorCode::TokenizerError, @@ -46,6 +44,15 @@ std::vector TokenizerModule::encode(std::string s) const { return encodeResult.get(); } +std::vector TokenizerModule::encode(std::string s) const { + return encodeImpl(s, numOfAddedBoSTokens, numOfAddedEoSTokens); +} + +std::vector +TokenizerModule::encodeWithSpecialTokens(std::string s) const { + return encodeImpl(s, /*bos=*/1, /*eos=*/1); +} + std::string TokenizerModule::decode(std::vector vec, bool skipSpecialTokens) const { if (!tokenizer) { diff --git a/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.h b/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.h index 3c90b25557..0e1356f121 100644 --- a/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.h +++ b/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.h @@ -13,6 +13,8 @@ class TokenizerModule { std::shared_ptr callInvoker); [[nodiscard("Registered non-void function")]] std::vector encode(std::string s) const; + [[nodiscard("Registered non-void function")]] std::vector + encodeWithSpecialTokens(std::string s) const; [[nodiscard("Registered non-void function")]] std::string decode(std::vector vec, bool skipSpecialTokens) const; [[nodiscard("Registered non-void function")]] std::string @@ -24,6 +26,9 @@ class TokenizerModule { std::size_t getMemoryLowerBound() const noexcept; private: + std::vector encodeImpl(const std::string &s, int8_t bos, + int8_t eos) const; + std::unique_ptr tokenizer; std::size_t memorySizeLowerBound{0}; }; diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index e4209b2f79..fdc87cd9af 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -707,6 +708,30 @@ getJsiValue(const models::style_transfer::PixelDataResult &result, return obj; } +inline jsi::Value getJsiValue(const models::embeddings::EmbeddingResult &result, + jsi::Runtime &runtime) { + jsi::Object obj(runtime); + + auto arrayBuffer = jsi::ArrayBuffer(runtime, result.dataPtr); + auto float32ArrayCtor = + runtime.global().getPropertyAsFunction(runtime, "Float32Array"); + auto float32Array = float32ArrayCtor.callAsConstructor(runtime, arrayBuffer) + .getObject(runtime); + obj.setProperty(runtime, "dataPtr", float32Array); + + obj.setProperty(runtime, "numTokens", jsi::Value(result.numTokens)); + obj.setProperty(runtime, "embeddingDim", jsi::Value(result.embeddingDim)); + + auto idsArray = jsi::Array(runtime, result.tokenIds.size()); + for (size_t i = 0; i < result.tokenIds.size(); ++i) { + idsArray.setValueAtIndex( + runtime, i, jsi::Value(static_cast(result.tokenIds[i]))); + } + obj.setProperty(runtime, "tokenIds", idsArray); + + return obj; +} + inline jsi::Value getJsiValue( const rnexecutorch::models::semantic_segmentation::SegmentationResult &result, diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp deleted file mode 100644 index bf291136c1..0000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp +++ /dev/null @@ -1,19 +0,0 @@ -#include "BaseEmbeddings.h" - -#include - -namespace rnexecutorch::models::embeddings { - -BaseEmbeddings::BaseEmbeddings(const std::string &modelSource, - std::shared_ptr callInvoker) - : BaseModel(modelSource, callInvoker) {} - -std::shared_ptr -BaseEmbeddings::postprocess(const Result> &forwardResult) { - auto forwardResultTensor = forwardResult->at(0).toTensor(); - auto buffer = std::make_shared( - forwardResultTensor.const_data_ptr(), forwardResultTensor.nbytes()); - return buffer; -} - -} // namespace rnexecutorch::models::embeddings diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.h b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.h deleted file mode 100644 index 216d6bf8ce..0000000000 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.h +++ /dev/null @@ -1,17 +0,0 @@ -#pragma once - -#include - -namespace rnexecutorch::models::embeddings { - -class BaseEmbeddings : public BaseModel { -public: - BaseEmbeddings(const std::string &modelSource, - std::shared_ptr callInvoker); - -protected: - std::shared_ptr - postprocess(const Result> &forwardResult); -}; - -}; // namespace rnexecutorch::models::embeddings diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/Types.h b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/Types.h new file mode 100644 index 0000000000..f2de1e899a --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/Types.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include +#include +#include + +namespace rnexecutorch::models::embeddings { + +// Text embedding output as a [numTokens, embeddingDim] fp32 matrix. Pooled +// single-vector models output numTokens == 1 (the exported graph pools + L2- +// normalizes); multi-vector (late-interaction / ColBERT) models output +// numTokens == sequence length. The TS layer reduces to a single vector or +// keeps the per-token matrix based on the model's config. `tokenIds` are the +// input ids (used JS-side for late-interaction skiplist masking). +struct EmbeddingResult { + std::shared_ptr dataPtr; + int32_t numTokens; + int32_t embeddingDim; + std::vector tokenIds; +}; + +} // namespace rnexecutorch::models::embeddings diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp index ba2c3243b2..6e5982c2a5 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp @@ -11,12 +11,15 @@ using namespace executorch::extension; TextEmbeddings::TextEmbeddings(const std::string &modelSource, const std::string &tokenizerSource, std::shared_ptr callInvoker) - : BaseEmbeddings(modelSource, callInvoker), + : BaseModel(modelSource, callInvoker), tokenizer( std::make_unique(tokenizerSource, callInvoker)) {} TokenIdsWithAttentionMask TextEmbeddings::preprocess(const std::string &input) { - auto inputIds = tokenizer->encode(input); + // Apply the tokenizer's post_processor so declared special tokens (e.g. a + // BOS prepended via TemplateProcessing) are added. CLS-pooled embedding + // models read position 0, so a missing BOS corrupts the pooled vector. + auto inputIds = tokenizer->encodeWithSpecialTokens(input); // Tokenizers-cpp return tokens as int32, but text embedding models require // int64 as input std::vector inputIds64; @@ -40,8 +43,7 @@ void TextEmbeddings::unload() noexcept { BaseModel::unload(); } -std::shared_ptr -TextEmbeddings::generate(const std::string input) { +EmbeddingResult TextEmbeddings::generate(const std::string input) { std::scoped_lock lock(inference_mutex_); auto preprocessed = preprocess(input); @@ -58,7 +60,41 @@ TextEmbeddings::generate(const std::string input) { auto forwardResult = BaseModel::forward({tokenIds, attnMask}); CHECK_OK_OR_THROW_FORWARD_ERROR(forwardResult); - return BaseEmbeddings::postprocess(forwardResult); + return buildResult(forwardResult->at(0).toTensor(), + std::move(preprocessed.inputIds)); +} + +// Output is [1, numTokens, embeddingDim] (numTokens == 1 for pooled models, +// == sequence length for multi-vector models). Multi-vector consumers index +// tokenIds[i] per output row (e.g. skiplist masking), so numTokens must match +// the input token count or that alignment silently breaks. +EmbeddingResult +TextEmbeddings::buildResult(const executorch::aten::Tensor &output, + std::vector tokenIds) { + auto sizes = output.sizes(); + if (sizes.size() < 2) { + throw RnExecutorchError(RnExecutorchErrorCode::InvalidModelOutput, + "Embedding output must be at least 2D, got rank " + + std::to_string(sizes.size())); + } + + const auto numTokens = static_cast(sizes[sizes.size() - 2]); + const auto inputTokens = static_cast(tokenIds.size()); + if (numTokens != 1 && numTokens != inputTokens) { + throw RnExecutorchError( + RnExecutorchErrorCode::InvalidModelOutput, + "Embedding output rows (" + std::to_string(numTokens) + + ") != input tokens (" + std::to_string(inputTokens) + + "); per-token tokenIds alignment is broken."); + } + + return EmbeddingResult{ + .dataPtr = std::make_shared(output.const_data_ptr(), + output.nbytes()), + .numTokens = numTokens, + .embeddingDim = static_cast(sizes[sizes.size() - 1]), + .tokenIds = std::move(tokenIds), + }; } } // namespace rnexecutorch::models::embeddings diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h index 93d0988c04..02cfefde4d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h @@ -3,7 +3,8 @@ #include "rnexecutorch/metaprogramming/ConstructorHelpers.h" #include #include -#include +#include +#include namespace rnexecutorch { namespace models::embeddings { @@ -13,13 +14,16 @@ struct TokenIdsWithAttentionMask { std::vector attentionMask; }; -class TextEmbeddings final : public BaseEmbeddings { +class TextEmbeddings final : public BaseModel { public: TextEmbeddings(const std::string &modelSource, const std::string &tokenizerSource, std::shared_ptr callInvoker); - [[nodiscard( - "Registered non-void function")]] std::shared_ptr + // Returns the raw [numTokens, embeddingDim] output. Pooled models give + // numTokens == 1; multi-vector (late-interaction) models give the full + // sequence. The TS layer reduces to a single vector or keeps the matrix + // based on the model's config. + [[nodiscard("Registered non-void function")]] EmbeddingResult generate(const std::string input); void unload() noexcept; @@ -27,6 +31,8 @@ class TextEmbeddings final : public BaseEmbeddings { mutable std::mutex inference_mutex_; std::vector> inputShapes; TokenIdsWithAttentionMask preprocess(const std::string &input); + static EmbeddingResult buildResult(const executorch::aten::Tensor &output, + std::vector tokenIds); std::unique_ptr tokenizer; }; } // namespace models::embeddings diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.cpp b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.cpp index 68a9a9fef4..3bf5fa2206 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.cpp @@ -16,9 +16,10 @@ Encoder::Encoder(const std::string &tokenizerSource, encoderSource, tokenizerSource, callInvoker)) {} std::vector Encoder::generate(std::string input) { - std::shared_ptr embeddingsText = encoder->generate(input); + std::shared_ptr embeddingsText = + encoder->generate(input).dataPtr; std::shared_ptr embeddingsUncond = - encoder->generate(std::string(constants::kBosToken)); + encoder->generate(std::string(constants::kBosToken)).dataPtr; assert(embeddingsText->size() == embeddingsUncond->size()); size_t embeddingsSize = embeddingsText->size() / sizeof(float); diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt index 5f9d7287a5..a901cd56fc 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt +++ b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt @@ -218,7 +218,6 @@ add_rn_test(ObjectDetectionTests integration/ObjectDetectionTest.cpp add_rn_test(ImageEmbeddingsTests integration/ImageEmbeddingsTest.cpp SOURCES ${RNEXECUTORCH_DIR}/models/embeddings/image/ImageEmbeddings.cpp - ${RNEXECUTORCH_DIR}/models/embeddings/BaseEmbeddings.cpp ${RNEXECUTORCH_DIR}/models/VisionModel.cpp ${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp ${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp @@ -230,7 +229,6 @@ add_rn_test(ImageEmbeddingsTests integration/ImageEmbeddingsTest.cpp add_rn_test(TextEmbeddingsTests integration/TextEmbeddingsTest.cpp SOURCES ${RNEXECUTORCH_DIR}/models/embeddings/text/TextEmbeddings.cpp - ${RNEXECUTORCH_DIR}/models/embeddings/BaseEmbeddings.cpp ${TOKENIZER_SOURCES} LIBS tokenizers_deps ) @@ -306,7 +304,6 @@ add_rn_test(TextToImageTests integration/TextToImageTest.cpp ${RNEXECUTORCH_DIR}/models/text_to_image/Decoder.cpp ${RNEXECUTORCH_DIR}/models/text_to_image/Scheduler.cpp ${RNEXECUTORCH_DIR}/models/embeddings/text/TextEmbeddings.cpp - ${RNEXECUTORCH_DIR}/models/embeddings/BaseEmbeddings.cpp ${TOKENIZER_SOURCES} LIBS tokenizers_deps ) diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextEmbeddingsTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextEmbeddingsTest.cpp index ff1abd4c30..0e0cc846b5 100644 --- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextEmbeddingsTest.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextEmbeddingsTest.cpp @@ -53,23 +53,23 @@ TEST(TextEmbeddingsGenerateTests, EmptyStringReturnsResults) { TextEmbeddings model(kValidTextEmbeddingsModelPath, kValidTextEmbeddingsTokenizerPath, nullptr); auto result = model.generate(""); - EXPECT_NE(result, nullptr); - EXPECT_GT(result->size(), 0u); + EXPECT_NE(result.dataPtr, nullptr); + EXPECT_GT(result.dataPtr->size(), 0u); } TEST(TextEmbeddingsGenerateTests, ValidTextReturnsResults) { TextEmbeddings model(kValidTextEmbeddingsModelPath, kValidTextEmbeddingsTokenizerPath, nullptr); auto result = model.generate("Hello, world!"); - EXPECT_NE(result, nullptr); - EXPECT_GT(result->size(), 0u); + EXPECT_NE(result.dataPtr, nullptr); + EXPECT_GT(result.dataPtr->size(), 0u); } TEST(TextEmbeddingsGenerateTests, ResultsHaveCorrectSize) { TextEmbeddings model(kValidTextEmbeddingsModelPath, kValidTextEmbeddingsTokenizerPath, nullptr); auto result = model.generate("This is a test sentence."); - size_t numFloats = result->size() / sizeof(float); + size_t numFloats = result.dataPtr->size() / sizeof(float); EXPECT_EQ(numFloats, kMiniLmEmbeddingDimensions); } @@ -78,8 +78,8 @@ TEST(TextEmbeddingsGenerateTests, ResultsAreNormalized) { kValidTextEmbeddingsTokenizerPath, nullptr); auto result = model.generate("The quick brown fox jumps over the lazy dog."); - const float *data = reinterpret_cast(result->data()); - size_t numFloats = result->size() / sizeof(float); + const float *data = reinterpret_cast(result.dataPtr->data()); + size_t numFloats = result.dataPtr->size() / sizeof(float); float sumOfSquares = 0.0f; for (size_t i = 0; i < numFloats; ++i) { @@ -94,8 +94,8 @@ TEST(TextEmbeddingsGenerateTests, ResultsContainValidValues) { kValidTextEmbeddingsTokenizerPath, nullptr); auto result = model.generate("Testing valid values."); - const float *data = reinterpret_cast(result->data()); - size_t numFloats = result->size() / sizeof(float); + const float *data = reinterpret_cast(result.dataPtr->data()); + size_t numFloats = result.dataPtr->size() / sizeof(float); for (size_t i = 0; i < numFloats; ++i) { EXPECT_FALSE(std::isnan(data[i])); @@ -110,9 +110,9 @@ TEST(TextEmbeddingsGenerateTests, DifferentTextProducesDifferentEmbeddings) { auto result1 = model.generate("Hello, world!"); auto result2 = model.generate("Goodbye, moon!"); - const float *data1 = reinterpret_cast(result1->data()); - const float *data2 = reinterpret_cast(result2->data()); - size_t numFloats = result1->size() / sizeof(float); + const float *data1 = reinterpret_cast(result1.dataPtr->data()); + const float *data2 = reinterpret_cast(result2.dataPtr->data()); + size_t numFloats = result1.dataPtr->size() / sizeof(float); bool allEqual = true; for (size_t i = 0; i < numFloats; ++i) { @@ -131,9 +131,9 @@ TEST(TextEmbeddingsGenerateTests, SimilarTextProducesSimilarEmbeddings) { auto result1 = model.generate("I love programming"); auto result2 = model.generate("I enjoy coding"); - const float *data1 = reinterpret_cast(result1->data()); - const float *data2 = reinterpret_cast(result2->data()); - size_t numFloats = result1->size() / sizeof(float); + const float *data1 = reinterpret_cast(result1.dataPtr->data()); + const float *data2 = reinterpret_cast(result2.dataPtr->data()); + size_t numFloats = result1.dataPtr->size() / sizeof(float); float dotProduct = 0.0f; for (size_t i = 0; i < numFloats; ++i) { diff --git a/packages/react-native-executorch/src/constants/modelRegistry.ts b/packages/react-native-executorch/src/constants/modelRegistry.ts index eb0c98dae7..4c36c6a1fa 100644 --- a/packages/react-native-executorch/src/constants/modelRegistry.ts +++ b/packages/react-native-executorch/src/constants/modelRegistry.ts @@ -260,6 +260,59 @@ const GEMMA4_E2B_MM_VARIANTS = { }, }; +// Asymmetric query/document prompts the LFM models are trained with. +// forward(text, role) auto-prepends these. +const LFM_EMBEDDING_PROMPTS = { query: 'query: ', document: 'document: ' }; +const LFM_COLBERT_PROMPTS = { query: '[Q] ', document: '[D] ' }; + +const LFM2_5_EMBEDDING_350M_VARIANTS = { + mlx: { + base: { + modelName: 'lfm2-5-embedding-350m' as const, + modelSource: M.LFM2_5_EMBEDDING_350M_MLX_MODEL, + tokenizerSource: M.LFM2_5_EMBEDDING_350M_TOKENIZER, + prompts: LFM_EMBEDDING_PROMPTS, + }, + }, + xnnpack: { + base: { + modelName: 'lfm2-5-embedding-350m' as const, + modelSource: M.LFM2_5_EMBEDDING_350M_XNNPACK_MODEL, + tokenizerSource: M.LFM2_5_EMBEDDING_350M_TOKENIZER, + prompts: LFM_EMBEDDING_PROMPTS, + }, + }, +}; + +const LFM_COLBERT_SKIP_LIST = [ + 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, + 535, 536, 537, 538, 539, 540, 541, 568, 569, 570, 571, 572, 573, 600, 601, + 602, 603, +]; + +const LFM2_5_COLBERT_350M_VARIANTS = { + mlx: { + base: { + modelName: 'lfm2-5-colbert-350m' as const, + modelSource: M.LFM2_5_COLBERT_350M_MLX_MODEL, + tokenizerSource: M.LFM2_5_COLBERT_350M_TOKENIZER, + prompts: LFM_COLBERT_PROMPTS, + multiVector: true as const, + skipListIds: LFM_COLBERT_SKIP_LIST, + }, + }, + xnnpack: { + base: { + modelName: 'lfm2-5-colbert-350m' as const, + modelSource: M.LFM2_5_COLBERT_350M_XNNPACK_MODEL, + tokenizerSource: M.LFM2_5_COLBERT_350M_TOKENIZER, + prompts: LFM_COLBERT_PROMPTS, + multiVector: true as const, + skipListIds: LFM_COLBERT_SKIP_LIST, + }, + }, +}; + const EFFICIENTNET_V2_S_VARIANTS = { xnnpack: { base: { @@ -742,6 +795,14 @@ export const models = { M.PARAPHRASE_MULTILINGUAL_MINILM_L12_V2_QUANTIZED ), clip_vit_base_patch32_text: base(M.CLIP_VIT_BASE_PATCH32_TEXT), + lfm2_5_embedding_350m: variant(LFM2_5_EMBEDDING_350M_VARIANTS, { + ios: 'mlx', + android: 'xnnpack', + }), + lfm2_5_colbert_350m: variant(LFM2_5_COLBERT_350M_VARIANTS, { + ios: 'mlx', + android: 'xnnpack', + }), }, image_embedding: { clip_vit_base_patch32_image: pair( diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts index 0e36f812ff..bd6cddf4a3 100644 --- a/packages/react-native-executorch/src/constants/modelUrls.ts +++ b/packages/react-native-executorch/src/constants/modelUrls.ts @@ -1195,6 +1195,12 @@ export const DISTILUSE_BASE_MULTILINGUAL_CASED_V2_8DA4W_MODEL = `${URL_PREFIX}-d export const DISTILUSE_BASE_MULTILINGUAL_CASED_V2_TOKENIZER = `${URL_PREFIX}-distiluse-base-multilingual-cased-v2/${PREVIOUS_VERSION_TAG}/tokenizer.json`; const PARAPHRASE_MULTILINGUAL_MINILM_L12_V2_QUANTIZED_MODEL = `${URL_PREFIX}-paraphrase-multilingual-MiniLM-L12-v2/${PREVIOUS_VERSION_TAG}/xnnpack/paraphrase_multilingual_minilm_l12_v2_xnnpack_8da4w.pte`; const PARAPHRASE_MULTILINGUAL_MINILM_L12_V2_TOKENIZER = `${URL_PREFIX}-paraphrase-multilingual-MiniLM-L12-v2/${PREVIOUS_VERSION_TAG}/tokenizer.json`; +export const LFM2_5_EMBEDDING_350M_XNNPACK_MODEL = `${URL_PREFIX}-lfm2.5-embedding-350m/${PREVIOUS_VERSION_TAG}/xnnpack/lfm_2_5_embedding_350m_xnnpack_8da4w.pte`; +export const LFM2_5_EMBEDDING_350M_MLX_MODEL = `${URL_PREFIX}-lfm2.5-embedding-350m/${PREVIOUS_VERSION_TAG}/mlx/lfm_2_5_embedding_350m_mlx_int4.pte`; +export const LFM2_5_EMBEDDING_350M_TOKENIZER = `${URL_PREFIX}-lfm2.5-embedding-350m/${PREVIOUS_VERSION_TAG}/tokenizer.json`; +export const LFM2_5_COLBERT_350M_XNNPACK_MODEL = `${URL_PREFIX}-lfm2.5-colbert-350m/${PREVIOUS_VERSION_TAG}/xnnpack/lfm_2_5_colbert_350m_xnnpack_8da4w.pte`; +export const LFM2_5_COLBERT_350M_MLX_MODEL = `${URL_PREFIX}-lfm2.5-colbert-350m/${PREVIOUS_VERSION_TAG}/mlx/lfm_2_5_colbert_350m_mlx_int4.pte`; +export const LFM2_5_COLBERT_350M_TOKENIZER = `${URL_PREFIX}-lfm2.5-colbert-350m/${PREVIOUS_VERSION_TAG}/tokenizer.json`; const CLIP_VIT_BASE_PATCH32_TEXT_MODEL = `${URL_PREFIX}-clip-vit-base-patch32/${PREVIOUS_VERSION_TAG}/xnnpack/clip_vit_base_patch32_text_xnnpack_fp32.pte`; const CLIP_VIT_BASE_PATCH32_TEXT_TOKENIZER = `${URL_PREFIX}-clip-vit-base-patch32/${PREVIOUS_VERSION_TAG}/tokenizer.json`; diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts index 31ee179925..2f100b8cbb 100644 --- a/packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts +++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts @@ -1,20 +1,25 @@ import { TextEmbeddingsModule } from '../../modules/natural_language_processing/TextEmbeddingsModule'; import { useModuleFactory } from '../useModuleFactory'; import { + EmbeddingRole, + ForwardFn, + TextEmbeddingsModel, TextEmbeddingsType, TextEmbeddingsProps, } from '../../types/textEmbeddings'; /** - * React hook for managing a Text Embeddings model instance. + * React hook for a Text Embeddings model. * @category Hooks - * @param TextEmbeddingsProps - Configuration object containing `model` source and optional `preventLoad` flag. - * @returns Ready to use Text Embeddings model. + * @param TextEmbeddingsProps - `model` source + optional `preventLoad`. + * @returns Ready to use embeddings model. `forward` returns the raw + * [numTokens, embeddingDim] result; use `toVector` for a single vector. + * Models with prompts require a `role` ('query' | 'document') on `forward`. */ -export const useTextEmbeddings = ({ +export const useTextEmbeddings = ({ model, preventLoad = false, -}: TextEmbeddingsProps): TextEmbeddingsType => { +}: TextEmbeddingsProps): TextEmbeddingsType => { const { error, isReady, isGenerating, downloadProgress, runForward } = useModuleFactory({ factory: (config, onProgress) => @@ -24,7 +29,8 @@ export const useTextEmbeddings = ({ preventLoad, }); - const forward = (input: string) => runForward((inst) => inst.forward(input)); + const forward = ((input: string, role?: EmbeddingRole) => + runForward((inst) => inst.forward(input, role))) as ForwardFn; return { error, isReady, isGenerating, downloadProgress, forward }; }; diff --git a/packages/react-native-executorch/src/index.ts b/packages/react-native-executorch/src/index.ts index 1f190d41f5..34cdf97d8d 100644 --- a/packages/react-native-executorch/src/index.ts +++ b/packages/react-native-executorch/src/index.ts @@ -212,6 +212,7 @@ export * from './utils/ResourceFetcher'; export * from './utils/ResourceFetcherUtils'; export * from './utils/BaseResourceFetcherClass'; export * from './utils/llm'; +export * from './utils/textEmbeddings'; export * from './common/Logger'; export * from './utils/llms/context_strategy'; export * from './utils/segmentAnythingPrompts'; diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts index 27b0e59ceb..abb620e981 100644 --- a/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts +++ b/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts @@ -1,5 +1,11 @@ import { ResourceSource } from '../../types/common'; -import { TextEmbeddingsModelName } from '../../types/textEmbeddings'; +import { + EmbeddingPrompts, + EmbeddingResult, + EmbeddingRole, + TextEmbeddingsModel, + TextEmbeddingsModelName, +} from '../../types/textEmbeddings'; import { ResourceFetcher } from '../../utils/ResourceFetcher'; import { BaseModule } from '../BaseModule'; import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; @@ -7,27 +13,35 @@ import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils'; import { Logger } from '../../common/Logger'; /** - * Module for generating text embeddings from input text. + * Module for text embeddings. `forward` returns a single pooled `Float32Array` + * for standard models, or the per-token `EmbeddingResult` for `multiVector` + * (late-interaction) models. The native runner always produces the raw + * [numTokens, embeddingDim] matrix; the reduction to a single vector happens + * here so the common single-vector API stays `Float32Array`. * @category Typescript API */ export class TextEmbeddingsModule extends BaseModule { - private constructor(nativeModule: unknown) { + private prompts?: EmbeddingPrompts; + private multiVector: boolean; + + private constructor( + nativeModule: unknown, + prompts: EmbeddingPrompts | undefined, + multiVector: boolean + ) { super(); this.nativeModule = nativeModule; + this.prompts = prompts; + this.multiVector = multiVector; } /** * Creates a text embeddings instance for a built-in model. - * @param namedSources - An object specifying which built-in model to load and where to fetch it from. - * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. - * @returns A Promise resolving to a `TextEmbeddingsModule` instance. + * @param namedSources - The model config (+ optional prompts / multiVector). + * @param onDownloadProgress - Optional download progress callback (0..1). */ static async fromModelName( - namedSources: { - modelName: TextEmbeddingsModelName; - modelSource: ResourceSource; - tokenizerSource: ResourceSource; - }, + namedSources: TextEmbeddingsModel, onDownloadProgress: (progress: number) => void = () => {} ): Promise { try { @@ -41,7 +55,9 @@ export class TextEmbeddingsModule extends BaseModule { throw new RnExecutorchError(RnExecutorchErrorCode.DownloadInterrupted); } return new TextEmbeddingsModule( - await global.loadTextEmbeddings(modelPath, tokenizerPath) + await global.loadTextEmbeddings(modelPath, tokenizerPath), + namedSources.prompts, + namedSources.multiVector ?? false ); } catch (error) { Logger.error('Load failed:', error); @@ -50,14 +66,9 @@ export class TextEmbeddingsModule extends BaseModule { } /** - * Creates a text embeddings instance with a user-provided model binary and tokenizer. - * Use this when working with a custom-exported model that is not one of the built-in presets. - * @remarks The native model contract for this method is not formally defined and may change - * between releases. Refer to the native source code for the current expected tensor interface. - * @param modelSource - A fetchable resource pointing to the model binary. - * @param tokenizerSource - A fetchable resource pointing to the tokenizer file. - * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. - * @returns A Promise resolving to a `TextEmbeddingsModule` instance. + * Creates a text embeddings instance from a custom model binary + tokenizer. + * @remarks The native tensor contract is not formally guaranteed across + * releases. */ static fromCustomModel( modelSource: ResourceSource, @@ -75,13 +86,32 @@ export class TextEmbeddingsModule extends BaseModule { } /** - * Executes the model's forward pass to generate an embedding for the provided text. - * @param input - The text string to embed. - * @returns A Promise resolving to a `Float32Array` containing the embedding vector. + * Embed text. Standard models return the single pooled `Float32Array`; + * `multiVector` models return the per-token `EmbeddingResult`. + * @param input - The text to embed. + * @param role - 'query' | 'document'; prepends the model's prompt for that + * role when configured (no-op otherwise). */ - async forward(input: string): Promise { + async forward( + input: string, + role?: EmbeddingRole + ): Promise { if (this.nativeModule == null) throw new RnExecutorchError(RnExecutorchErrorCode.ModuleNotLoaded); - return new Float32Array(await this.nativeModule.generate(input)); + const prefix = (role && this.prompts?.[role]) || ''; + const res = await this.nativeModule.generate(prefix + input); + // res.dataPtr is already a Float32Array view over the owned native buffer + // (built at the JSI boundary). + const vectors = res.dataPtr as Float32Array; + if (!this.multiVector) { + // Pooled models output [1, embeddingDim]; return that single row. + return vectors.subarray(0, res.embeddingDim); + } + return { + vectors, + numTokens: res.numTokens, + embeddingDim: res.embeddingDim, + tokenIds: res.tokenIds, + }; } } diff --git a/packages/react-native-executorch/src/types/textEmbeddings.ts b/packages/react-native-executorch/src/types/textEmbeddings.ts index d9cd120e26..1b056a1f7b 100644 --- a/packages/react-native-executorch/src/types/textEmbeddings.ts +++ b/packages/react-native-executorch/src/types/textEmbeddings.ts @@ -12,65 +12,124 @@ export type TextEmbeddingsModelName = | 'multi-qa-mpnet-base-dot-v1' | 'distiluse-base-multilingual-cased-v2-8da4w' | 'paraphrase-multilingual-minilm-l12-v2-quantized' - | 'clip-vit-base-patch32-text'; + | 'clip-vit-base-patch32-text' + | 'lfm2-5-embedding-350m' + | 'lfm2-5-colbert-350m'; /** - * Props for the useTextEmbeddings hook. + * Per-token (multi-vector) embedding output for late-interaction models (e.g. + * ColBERT): a [numTokens, embeddingDim] fp32 matrix (row-major) plus the input + * token ids. Standard models return a single pooled `Float32Array` from + * `forward` instead; only `multiVector` models yield this. * @category Types - * @property {object} model - An object containing the model configuration. - * @property {TextEmbeddingsModelName} model.modelName - Unique name identifying the model. - * @property {ResourceSource} model.modelSource - The source of the text embeddings model binary. - * @property {ResourceSource} model.tokenizerSource - The source of the tokenizer JSON file. - * @property {boolean} [preventLoad] - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. */ -export interface TextEmbeddingsProps { - model: { - /** - * The unique name of the text embeddings model. - */ - modelName: TextEmbeddingsModelName; - /** - * The source of the text embeddings model binary. - */ - modelSource: ResourceSource; - /** - * The source of the tokenizer JSON file. - */ - tokenizerSource: ResourceSource; - }; - preventLoad?: boolean; +export interface EmbeddingResult { + /** Flat [numTokens * embeddingDim] fp32 vectors (row-major). */ + vectors: Float32Array; + /** Number of token rows. */ + numTokens: number; + /** Per-token vector dimension. */ + embeddingDim: number; + /** Input token ids per row. */ + tokenIds: number[]; } /** - * React hook state and methods for managing a Text Embeddings model instance. + * Role for `forward`. Some models are trained with asymmetric query/document + * prompts (e.g. LFM2.5 uses `query: `/`document: `, ColBERT uses `[Q] `/`[D] `). + * Passing a role auto-prepends the model's configured prompt for that role. * @category Types */ -export interface TextEmbeddingsType { - /** - * Contains the error message if the model failed to load or during inference. - */ - error: null | RnExecutorchError; +export type EmbeddingRole = 'query' | 'document'; - /** - * Indicates whether the embeddings model has successfully loaded and is ready for inference. - */ - isReady: boolean; +/** + * Asymmetric prompts a model is trained with. When a model config carries + * these, `forward` REQUIRES a `role` so the matching prompt is always applied + * (forgetting it would silently embed raw text and wreck asymmetric retrieval). + * @category Types + */ +export interface EmbeddingPrompts { + query: string; + document: string; +} +/** + * A text embeddings model config. Two optional flags drive `forward`: + * - `prompts` present -> `forward` REQUIRES a `role` (auto-prepends the prompt) + * - `multiVector` true -> `forward` returns the per-token `EmbeddingResult`; + * otherwise it returns a single pooled `Float32Array`. + * @category Types + */ +export interface TextEmbeddingsModel { + modelName: TextEmbeddingsModelName; + modelSource: ResourceSource; + tokenizerSource: ResourceSource; + prompts?: EmbeddingPrompts; + multiVector?: boolean; /** - * Indicates whether the model is currently generating embeddings. + * Document token ids to exclude from late-interaction scoring (e.g. ColBERT's + * punctuation skipList). Derived from the model's training config, so it's + * shipped here rather than reconstructed by the consumer, who passes it to + * their own MaxSim scoring. */ - isGenerating: boolean; + skipListIds?: number[]; +} - /** - * Tracks the progress of the model download process (value between 0 and 1). - */ +/** + * `forward`'s signature, computed from the model config: + * - return type: `EmbeddingResult` if `multiVector`, else `Float32Array`. + * - role arg: required if the model has `prompts`, else absent. + */ +export type ForwardReturn = M extends { + multiVector: true; +} + ? EmbeddingResult + : Float32Array; + +/** + * `forward`'s signature, computed from the model config: + * - A model that DEFINITELY has prompts -> `role` is REQUIRED. + * - A model that definitely has NO prompts (`prompts?: undefined`) -> no role. + * - Otherwise (prompts optional / unknown, e.g. a heterogeneous model list) -> + * `role` is OPTIONAL. + */ +export type ForwardFn = M extends { + prompts: EmbeddingPrompts; +} + ? (input: string, role: EmbeddingRole) => Promise> + : undefined extends M['prompts'] + ? M['prompts'] extends undefined + ? (input: string) => Promise> + : (input: string, role?: EmbeddingRole) => Promise> + : (input: string) => Promise>; + +/** + * Props for the useTextEmbeddings hook. + * @category Types + */ +export interface TextEmbeddingsProps< + M extends TextEmbeddingsModel = TextEmbeddingsModel, +> { + model: M; + preventLoad?: boolean; +} + +/** + * React hook state and methods for a Text Embeddings model instance. + * @category Types + */ +export interface TextEmbeddingsType< + M extends TextEmbeddingsModel = TextEmbeddingsModel, +> { + error: null | RnExecutorchError; + isReady: boolean; + isGenerating: boolean; downloadProgress: number; /** - * Runs the text embeddings model on the provided input string. - * @param input - The text string to embed. - * @returns A promise resolving to a Float32Array containing the vector embeddings. - * @throws {RnExecutorchError} If the model is not loaded or is currently processing another request. + * Embed text. Standard models return a single pooled `Float32Array`; + * `multiVector` models return the per-token `EmbeddingResult`. Models with + * `prompts` require a `role` ('query' | 'document'). */ - forward(input: string): Promise; + forward: ForwardFn; } diff --git a/packages/react-native-executorch/src/utils/textEmbeddings.ts b/packages/react-native-executorch/src/utils/textEmbeddings.ts new file mode 100644 index 0000000000..e9be7cf774 --- /dev/null +++ b/packages/react-native-executorch/src/utils/textEmbeddings.ts @@ -0,0 +1,37 @@ +import { EmbeddingResult } from '../types/textEmbeddings'; + +/** + * Get the single pooled embedding vector from a result. Convenience for the + * common single-vector case: the exported graph pools + L2-normalizes to a + * [1, embeddingDim] output, so this returns row 0. + * + * For multi-vector (late-interaction) models, prefer the full per-token + * vectors (`getTokenVectors`); row 0 alone is not a meaningful sentence + * embedding there. + * + * @category Utils + */ +export function toVector(result: EmbeddingResult): Float32Array { + return result.vectors.slice(0, result.embeddingDim); +} + +/** + * Split a result's flat `vectors` buffer into per-token rows + * (`numTokens` arrays of length `embeddingDim`). Useful for inspecting or + * storing individual token vectors (e.g. a multi-vector vector DB). + * + * The rows are zero-copy `subarray` VIEWS over `result.vectors` — valid only + * while that buffer is alive and not mutated. Copy them (e.g. `new + * Float32Array(row)`) before storing beyond the result's lifetime. (`toVector` + * by contrast returns an independent copy.) + * + * @category Utils + */ +export function getTokenVectors(result: EmbeddingResult): Float32Array[] { + const { vectors, numTokens, embeddingDim } = result; + const rows: Float32Array[] = []; + for (let i = 0; i < numTokens; i++) { + rows.push(vectors.subarray(i * embeddingDim, (i + 1) * embeddingDim)); + } + return rows; +}