diff --git a/apps/text-embeddings/app/text-embeddings/index.tsx b/apps/text-embeddings/app/text-embeddings/index.tsx
index 88e39ce063..2c62a22922 100644
--- a/apps/text-embeddings/app/text-embeddings/index.tsx
+++ b/apps/text-embeddings/app/text-embeddings/index.tsx
@@ -5,7 +5,6 @@ import {
TextInput,
TouchableOpacity,
View,
- SafeAreaView,
ScrollView,
KeyboardAvoidingView,
Platform,
@@ -16,10 +15,17 @@ import {
models,
useTextEmbeddings,
TextEmbeddingsProps,
+ EmbeddingResult,
} from 'react-native-executorch';
+import { useIsFocused } from 'expo-router';
+import { dotProduct, maxSim } from '../../utils/math';
+import ErrorBanner from '../../components/ErrorBanner';
+import { SafeAreaView } from 'react-native-safe-area-context';
+
const textEmbedding = models.text_embedding;
type TextEmbeddingModel = TextEmbeddingsProps['model'];
+type Encoding = Float32Array | EmbeddingResult;
const MODELS: { label: string; value: TextEmbeddingModel }[] = [
{ label: 'MiniLM L6', value: textEmbedding.all_minilm_l6_v2() },
@@ -43,10 +49,42 @@ const MODELS: { label: string; value: TextEmbeddingModel }[] = [
label: 'Multilingual Paraphrase',
value: textEmbedding.paraphrase_multilingual_minilm_l12_v2(),
},
+ {
+ label: 'LFM2.5 Embedding XNNPACK',
+ value: textEmbedding.lfm2_5_embedding_350m({ backend: 'xnnpack' }),
+ },
+ {
+ label: 'LFM2.5 Embedding MLX',
+ value: textEmbedding.lfm2_5_embedding_350m({ backend: 'mlx' }),
+ },
+ {
+ label: 'LFM2.5 ColBERT (late-interaction)',
+ value: textEmbedding.lfm2_5_colbert_350m(),
+ },
+];
+
+const CORPUS: string[] = [
+ 'The forecast says heavy showers this afternoon.',
+ "It's so sunny outside today!",
+ 'A thick fog rolled in over the harbor at dawn.',
+ 'The home team scored in the final minute to win the match.',
+ 'She sprinted the last lap and broke the national record.',
+ 'Fans packed the stadium for the championship game.',
+ 'Simmer the tomatoes with garlic before adding the pasta.',
+ 'He whisked the eggs and folded in the melted chocolate.',
+ 'The new phone has a faster chip and a brighter screen.',
+ 'Our servers crashed under the sudden spike in traffic.',
+ 'The flight to Tokyo was delayed by three hours.',
+ 'We hiked along the coast and camped near the cliffs.',
+];
+
+const EXAMPLE_QUERIES: string[] = [
+ "What's the weather like?",
+ 'Who won the match?',
+ 'Tell me about the latest technology',
+ 'How do I cook dinner?',
+ 'Where did they travel?',
];
-import { useIsFocused } from 'expo-router';
-import { dotProduct } from '../../utils/math';
-import ErrorBanner from '../../components/ErrorBanner';
export default function TextEmbeddingsScreenWrapper() {
const isFocused = useIsFocused();
@@ -54,6 +92,8 @@ export default function TextEmbeddingsScreenWrapper() {
return isFocused ? : null;
}
+type RankedResult = { sentence: string; similarity: number };
+
function TextEmbeddingsScreen() {
const [selectedModel, setSelectedModel] = useState(
textEmbedding.all_minilm_l6_v2()
@@ -61,88 +101,70 @@ function TextEmbeddingsScreen() {
const model = useTextEmbeddings({ model: selectedModel });
const [error, setError] = useState(null);
- const [inputSentence, setInputSentence] = useState('');
- const [sentencesWithEmbeddings, setSentencesWithEmbeddings] = useState<
- { sentence: string; embedding: Float32Array }[]
- >([]);
- const [topMatches, setTopMatches] = useState<
- { sentence: string; similarity: number }[]
+ const isMultiVector = !!selectedModel.multiVector;
+ const skipListIds = selectedModel.skipListIds ?? [];
+
+ const [query, setQuery] = useState('');
+ const [corpusEmbeddings, setCorpusEmbeddings] = useState<
+ { sentence: string; embedding: Encoding }[]
>([]);
+ const [results, setResults] = useState([]);
const [embeddingTime, setEmbeddingTime] = useState(null);
+ const [indexing, setIndexing] = useState(false);
useEffect(
() => {
- const computeEmbeddings = async () => {
+ let cancelled = false;
+ const indexCorpus = async () => {
if (!model.isReady) return;
-
- const sentences = [
- 'The weather is lovely today.',
- "It's so sunny outside!",
- 'He drove to the stadium.',
- ];
-
+ setIndexing(true);
+ setResults([]);
try {
- const embeddings = [];
- for (const sentence of sentences) {
- const embedding = await model.forward(sentence);
- embeddings.push({ sentence, embedding });
+ const embedded = [];
+ for (const sentence of CORPUS) {
+ const embedding = await model.forward(sentence, 'document');
+ if (cancelled) return;
+ embedded.push({ sentence, embedding });
}
-
- setSentencesWithEmbeddings(embeddings);
- } catch (e) {
- setError(e instanceof Error ? e.message : String(e));
+ setCorpusEmbeddings(embedded);
+ } finally {
+ if (!cancelled) setIndexing(false);
}
};
-
- computeEmbeddings();
+ indexCorpus();
+ return () => {
+ cancelled = true;
+ };
},
+
// eslint-disable-next-line react-hooks/exhaustive-deps
- [model.isReady]
+ [model.isReady, selectedModel]
);
- const checkSimilarities = async () => {
- if (!model.isReady || !inputSentence.trim()) return;
-
+ const runSearch = async (queryText: string = query) => {
+ const q = queryText.trim();
+ if (!model.isReady || !q || corpusEmbeddings.length === 0) return;
+ setQuery(queryText);
try {
const start = Date.now();
- const inputEmbedding = await model.forward(inputSentence);
+ const queryEmbedding = (await model.forward(q, 'query')) as Encoding;
setEmbeddingTime(Date.now() - start);
- const matches = sentencesWithEmbeddings.map(
- ({ sentence, embedding }) => ({
+ const ranked = corpusEmbeddings
+ .map(({ sentence, embedding }) => ({
sentence,
- similarity: dotProduct(inputEmbedding, embedding),
- })
- );
- matches.sort((a, b) => b.similarity - a.similarity);
- setTopMatches(matches.slice(0, 3));
- } catch (e) {
- setError(e instanceof Error ? e.message : String(e));
- }
- };
-
- const addToSentences = async () => {
- if (!model.isReady || !inputSentence.trim()) return;
-
- try {
- const start = Date.now();
- const embedding = await model.forward(inputSentence);
- setEmbeddingTime(Date.now() - start);
- setSentencesWithEmbeddings((prev) => [
- ...prev,
- { sentence: inputSentence, embedding },
- ]);
- } catch (e) {
- setError(e instanceof Error ? e.message : String(e));
- }
-
- setInputSentence('');
- setTopMatches([]);
- };
-
- const clearList = async () => {
- if (!model.isReady) return;
- try {
- setSentencesWithEmbeddings([]);
+ similarity: isMultiVector
+ ? maxSim(
+ queryEmbedding as EmbeddingResult,
+ embedding as EmbeddingResult,
+ skipListIds
+ )
+ : dotProduct(
+ queryEmbedding as Float32Array,
+ embedding as Float32Array
+ ),
+ }))
+ .sort((a, b) => b.similarity - a.similarity);
+ setResults(ranked);
} catch (e) {
setError(e instanceof Error ? e.message : String(e));
}
@@ -158,6 +180,9 @@ function TextEmbeddingsScreen() {
return model.isGenerating ? 'Generating...' : 'Model is ready';
};
+ const ready = model.isReady && !indexing && corpusEmbeddings.length > 0;
+ const canSearch = ready && !!query.trim();
+
return (
- Text Embeddings Playground
+ Semantic Search
{getModelStatusText()}
{
setSelectedModel(m);
- setSentencesWithEmbeddings([]);
- setTopMatches([]);
+ setCorpusEmbeddings([]);
+ setResults([]);
+ setQuery('');
}}
/>
setError(null)} />
- List of Existing Sentences
- {sentencesWithEmbeddings.map((item, index) => (
-
- - {item.sentence}
-
- ))}
-
-
- Try Your Sentence
+
+ Search the corpus ({CORPUS.length} sentences)
+
+
+ {isMultiVector
+ ? 'Ranks per-token vectors with MaxSim (late interaction). Ask a full question — tap an example or type your own.'
+ : 'Ranks every sentence by meaning. Ask a full question — tap an example or type your own.'}
+
+
+ {EXAMPLE_QUERIES.map((q) => (
+ runSearch(q)}
+ >
+ {q}
+
+ ))}
+
runSearch()}
+ returnKeyType="search"
/>
-
- runSearch()}
+ style={[
+ styles.buttonPrimary,
+ !canSearch && styles.buttonDisabled,
+ ]}
+ disabled={!canSearch}
+ >
+
+
-
-
- Find Similar
-
-
-
-
-
-
- Add to List
-
-
-
-
-
- Clear List
-
-
-
-
+ {indexing ? 'Indexing corpus…' : 'Search'}
+
+
{embeddingTime !== null && (
- Embedding time: {embeddingTime} ms
+ Query embedded in {embeddingTime} ms
)}
- {topMatches.length > 0 && (
-
- Top Matches
- {topMatches.map((item, index) => (
-
- {item.sentence} ({item.similarity.toFixed(2)})
-
- ))}
-
- )}
+
+ {results.length > 0 && (
+
+ Results
+ {results.map((item, index) => (
+
+ ))}
+
+ )}
);
}
+function ResultRow({
+ sentence,
+ similarity,
+ best,
+ rank,
+}: {
+ sentence: string;
+ similarity: number;
+ best: number;
+ rank: number;
+}) {
+ const fraction = best > 0 ? Math.max(0, similarity / best) : 0;
+ return (
+
+
+ {sentence}
+ {similarity.toFixed(2)}
+
+
+
+
+
+ );
+}
+
const styles = StyleSheet.create({
container: {
flex: 1,
@@ -323,11 +345,68 @@ const styles = StyleSheet.create({
marginBottom: 12,
color: '#1E293B',
},
- sentenceText: {
- fontSize: 14,
+ hint: {
+ fontSize: 13,
+ color: '#64748B',
+ marginBottom: 12,
+ lineHeight: 18,
+ },
+ chipRow: {
+ flexDirection: 'row',
+ flexWrap: 'wrap',
+ gap: 8,
+ marginBottom: 12,
+ },
+ chip: {
+ backgroundColor: '#EEF2FF',
+ borderColor: '#C7D2FE',
+ borderWidth: 1,
+ borderRadius: 16,
+ paddingHorizontal: 12,
+ paddingVertical: 6,
+ },
+ chipDisabled: {
+ opacity: 0.4,
+ },
+ chipText: {
+ fontSize: 13,
+ color: 'navy',
+ },
+ resultRow: {
+ marginBottom: 14,
+ },
+ resultHeader: {
+ flexDirection: 'row',
+ justifyContent: 'space-between',
+ alignItems: 'flex-start',
marginBottom: 6,
+ gap: 8,
+ },
+ resultText: {
+ flex: 1,
+ fontSize: 14,
color: '#334155',
},
+ resultScore: {
+ fontSize: 14,
+ fontWeight: '600',
+ color: '#0F172A',
+ fontVariant: ['tabular-nums'],
+ },
+ barTrack: {
+ height: 8,
+ borderRadius: 4,
+ backgroundColor: '#E2E8F0',
+ overflow: 'hidden',
+ },
+ barFill: {
+ height: '100%',
+ borderRadius: 4,
+ backgroundColor: '#94A3B8',
+ },
+ barFillTop: {
+ backgroundColor: 'navy',
+ },
input: {
backgroundColor: '#F1F5F9',
borderRadius: 10,
@@ -338,17 +417,8 @@ const styles = StyleSheet.create({
minHeight: 40,
textAlignVertical: 'top',
},
- buttonContainer: {
- width: '100%',
- gap: 10,
- },
- buttonGroup: {
- flexDirection: 'row',
- justifyContent: 'space-between',
- gap: 10,
- },
buttonPrimary: {
- flex: 1,
+ width: '100%',
backgroundColor: 'navy',
padding: 12,
borderRadius: 10,
@@ -356,17 +426,6 @@ const styles = StyleSheet.create({
alignItems: 'center',
justifyContent: 'center',
},
- buttonSecondary: {
- flex: 1,
- backgroundColor: 'transparent',
- borderWidth: 2,
- borderColor: 'navy',
- padding: 12,
- borderRadius: 10,
- flexDirection: 'row',
- alignItems: 'center',
- justifyContent: 'center',
- },
buttonDisabled: {
backgroundColor: '#f0f0f0',
borderColor: '#d3d3d3',
@@ -376,17 +435,9 @@ const styles = StyleSheet.create({
textAlign: 'center',
fontWeight: '500',
},
- buttonTextOutline: {
- color: 'navy',
- textAlign: 'center',
- fontWeight: '500',
- },
buttonTextDisabled: {
color: 'gray',
},
- topMatchesContainer: {
- marginTop: 20,
- },
statsText: {
fontSize: 13,
color: '#64748B',
diff --git a/apps/text-embeddings/utils/math.ts b/apps/text-embeddings/utils/math.ts
index 50c70d1f92..44248e1658 100644
--- a/apps/text-embeddings/utils/math.ts
+++ b/apps/text-embeddings/utils/math.ts
@@ -1,6 +1,7 @@
import {
RnExecutorchError,
RnExecutorchErrorCode,
+ EmbeddingResult,
} from 'react-native-executorch';
export const dotProduct = (a: Float32Array, b: Float32Array) => {
@@ -17,3 +18,28 @@ export const dotProduct = (a: Float32Array, b: Float32Array) => {
}
return sum;
};
+
+export const maxSim = (
+ query: EmbeddingResult,
+ doc: EmbeddingResult,
+ skipListIds: number[] = []
+) => {
+ const dim = query.embeddingDim;
+ const skip = new Set(skipListIds);
+ let score = 0;
+ for (let qi = 0; qi < query.numTokens; qi++) {
+ const qOff = qi * dim;
+ let best = -Infinity;
+ for (let di = 0; di < doc.numTokens; di++) {
+ if (skip.has(doc.tokenIds[di]!)) continue;
+ const dOff = di * dim;
+ let dot = 0;
+ for (let k = 0; k < dim; k++) {
+ dot += (query.vectors[qOff + k] ?? 0) * (doc.vectors[dOff + k] ?? 0);
+ }
+ if (dot > best) best = dot;
+ }
+ if (best !== -Infinity) score += best;
+ }
+ return score;
+};
diff --git a/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.cpp b/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.cpp
index 76e0fb90c7..dfd9243c48 100644
--- a/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.cpp
@@ -26,17 +26,15 @@ TokenizerModule::TokenizerModule(
memorySizeLowerBound = std::filesystem::file_size(modelPath);
}
-std::vector TokenizerModule::encode(std::string s) const {
+// When the tokenizer.json defines a post_processor, the underlying HFTokenizer
+// treats non-zero bos/eos as a flag to run it with add_special_token=true (not
+// a literal count). So bos=eos=0 skips special tokens; bos=eos=1 applies them.
+std::vector TokenizerModule::encodeImpl(const std::string &s,
+ int8_t bos, int8_t eos) const {
if (!tokenizer) {
THROW_NOT_LOADED_ERROR();
}
-
- // If the used tokenizer.json has defined post_processor field,
- // setting any of bos or eos arguments to value other than provided constant
- // ( which is 0) will result in running the post_processor with
- // 'add_special_token' flag
- auto encodeResult =
- tokenizer->encode(s, numOfAddedBoSTokens, numOfAddedEoSTokens);
+ auto encodeResult = tokenizer->encode(s, bos, eos);
if (!encodeResult.ok()) {
throw RnExecutorchError(
RnExecutorchErrorCode::TokenizerError,
@@ -46,6 +44,15 @@ std::vector TokenizerModule::encode(std::string s) const {
return encodeResult.get();
}
+std::vector TokenizerModule::encode(std::string s) const {
+ return encodeImpl(s, numOfAddedBoSTokens, numOfAddedEoSTokens);
+}
+
+std::vector
+TokenizerModule::encodeWithSpecialTokens(std::string s) const {
+ return encodeImpl(s, /*bos=*/1, /*eos=*/1);
+}
+
std::string TokenizerModule::decode(std::vector vec,
bool skipSpecialTokens) const {
if (!tokenizer) {
diff --git a/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.h b/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.h
index 3c90b25557..0e1356f121 100644
--- a/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.h
+++ b/packages/react-native-executorch/common/rnexecutorch/TokenizerModule.h
@@ -13,6 +13,8 @@ class TokenizerModule {
std::shared_ptr callInvoker);
[[nodiscard("Registered non-void function")]] std::vector
encode(std::string s) const;
+ [[nodiscard("Registered non-void function")]] std::vector
+ encodeWithSpecialTokens(std::string s) const;
[[nodiscard("Registered non-void function")]] std::string
decode(std::vector vec, bool skipSpecialTokens) const;
[[nodiscard("Registered non-void function")]] std::string
@@ -24,6 +26,9 @@ class TokenizerModule {
std::size_t getMemoryLowerBound() const noexcept;
private:
+ std::vector encodeImpl(const std::string &s, int8_t bos,
+ int8_t eos) const;
+
std::unique_ptr tokenizer;
std::size_t memorySizeLowerBound{0};
};
diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h
index e4209b2f79..fdc87cd9af 100644
--- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h
+++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h
@@ -17,6 +17,7 @@
#include
#include
+#include
#include
#include
#include
@@ -707,6 +708,30 @@ getJsiValue(const models::style_transfer::PixelDataResult &result,
return obj;
}
+inline jsi::Value getJsiValue(const models::embeddings::EmbeddingResult &result,
+ jsi::Runtime &runtime) {
+ jsi::Object obj(runtime);
+
+ auto arrayBuffer = jsi::ArrayBuffer(runtime, result.dataPtr);
+ auto float32ArrayCtor =
+ runtime.global().getPropertyAsFunction(runtime, "Float32Array");
+ auto float32Array = float32ArrayCtor.callAsConstructor(runtime, arrayBuffer)
+ .getObject(runtime);
+ obj.setProperty(runtime, "dataPtr", float32Array);
+
+ obj.setProperty(runtime, "numTokens", jsi::Value(result.numTokens));
+ obj.setProperty(runtime, "embeddingDim", jsi::Value(result.embeddingDim));
+
+ auto idsArray = jsi::Array(runtime, result.tokenIds.size());
+ for (size_t i = 0; i < result.tokenIds.size(); ++i) {
+ idsArray.setValueAtIndex(
+ runtime, i, jsi::Value(static_cast(result.tokenIds[i])));
+ }
+ obj.setProperty(runtime, "tokenIds", idsArray);
+
+ return obj;
+}
+
inline jsi::Value getJsiValue(
const rnexecutorch::models::semantic_segmentation::SegmentationResult
&result,
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp
deleted file mode 100644
index bf291136c1..0000000000
--- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.cpp
+++ /dev/null
@@ -1,19 +0,0 @@
-#include "BaseEmbeddings.h"
-
-#include
-
-namespace rnexecutorch::models::embeddings {
-
-BaseEmbeddings::BaseEmbeddings(const std::string &modelSource,
- std::shared_ptr callInvoker)
- : BaseModel(modelSource, callInvoker) {}
-
-std::shared_ptr
-BaseEmbeddings::postprocess(const Result> &forwardResult) {
- auto forwardResultTensor = forwardResult->at(0).toTensor();
- auto buffer = std::make_shared(
- forwardResultTensor.const_data_ptr(), forwardResultTensor.nbytes());
- return buffer;
-}
-
-} // namespace rnexecutorch::models::embeddings
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.h b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.h
deleted file mode 100644
index 216d6bf8ce..0000000000
--- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/BaseEmbeddings.h
+++ /dev/null
@@ -1,17 +0,0 @@
-#pragma once
-
-#include
-
-namespace rnexecutorch::models::embeddings {
-
-class BaseEmbeddings : public BaseModel {
-public:
- BaseEmbeddings(const std::string &modelSource,
- std::shared_ptr callInvoker);
-
-protected:
- std::shared_ptr
- postprocess(const Result> &forwardResult);
-};
-
-}; // namespace rnexecutorch::models::embeddings
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/Types.h b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/Types.h
new file mode 100644
index 0000000000..f2de1e899a
--- /dev/null
+++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/Types.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+
+namespace rnexecutorch::models::embeddings {
+
+// Text embedding output as a [numTokens, embeddingDim] fp32 matrix. Pooled
+// single-vector models output numTokens == 1 (the exported graph pools + L2-
+// normalizes); multi-vector (late-interaction / ColBERT) models output
+// numTokens == sequence length. The TS layer reduces to a single vector or
+// keeps the per-token matrix based on the model's config. `tokenIds` are the
+// input ids (used JS-side for late-interaction skiplist masking).
+struct EmbeddingResult {
+ std::shared_ptr dataPtr;
+ int32_t numTokens;
+ int32_t embeddingDim;
+ std::vector tokenIds;
+};
+
+} // namespace rnexecutorch::models::embeddings
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp
index ba2c3243b2..6e5982c2a5 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp
@@ -11,12 +11,15 @@ using namespace executorch::extension;
TextEmbeddings::TextEmbeddings(const std::string &modelSource,
const std::string &tokenizerSource,
std::shared_ptr callInvoker)
- : BaseEmbeddings(modelSource, callInvoker),
+ : BaseModel(modelSource, callInvoker),
tokenizer(
std::make_unique(tokenizerSource, callInvoker)) {}
TokenIdsWithAttentionMask TextEmbeddings::preprocess(const std::string &input) {
- auto inputIds = tokenizer->encode(input);
+ // Apply the tokenizer's post_processor so declared special tokens (e.g. a
+ // BOS prepended via TemplateProcessing) are added. CLS-pooled embedding
+ // models read position 0, so a missing BOS corrupts the pooled vector.
+ auto inputIds = tokenizer->encodeWithSpecialTokens(input);
// Tokenizers-cpp return tokens as int32, but text embedding models require
// int64 as input
std::vector inputIds64;
@@ -40,8 +43,7 @@ void TextEmbeddings::unload() noexcept {
BaseModel::unload();
}
-std::shared_ptr
-TextEmbeddings::generate(const std::string input) {
+EmbeddingResult TextEmbeddings::generate(const std::string input) {
std::scoped_lock lock(inference_mutex_);
auto preprocessed = preprocess(input);
@@ -58,7 +60,41 @@ TextEmbeddings::generate(const std::string input) {
auto forwardResult = BaseModel::forward({tokenIds, attnMask});
CHECK_OK_OR_THROW_FORWARD_ERROR(forwardResult);
- return BaseEmbeddings::postprocess(forwardResult);
+ return buildResult(forwardResult->at(0).toTensor(),
+ std::move(preprocessed.inputIds));
+}
+
+// Output is [1, numTokens, embeddingDim] (numTokens == 1 for pooled models,
+// == sequence length for multi-vector models). Multi-vector consumers index
+// tokenIds[i] per output row (e.g. skiplist masking), so numTokens must match
+// the input token count or that alignment silently breaks.
+EmbeddingResult
+TextEmbeddings::buildResult(const executorch::aten::Tensor &output,
+ std::vector tokenIds) {
+ auto sizes = output.sizes();
+ if (sizes.size() < 2) {
+ throw RnExecutorchError(RnExecutorchErrorCode::InvalidModelOutput,
+ "Embedding output must be at least 2D, got rank " +
+ std::to_string(sizes.size()));
+ }
+
+ const auto numTokens = static_cast(sizes[sizes.size() - 2]);
+ const auto inputTokens = static_cast(tokenIds.size());
+ if (numTokens != 1 && numTokens != inputTokens) {
+ throw RnExecutorchError(
+ RnExecutorchErrorCode::InvalidModelOutput,
+ "Embedding output rows (" + std::to_string(numTokens) +
+ ") != input tokens (" + std::to_string(inputTokens) +
+ "); per-token tokenIds alignment is broken.");
+ }
+
+ return EmbeddingResult{
+ .dataPtr = std::make_shared(output.const_data_ptr(),
+ output.nbytes()),
+ .numTokens = numTokens,
+ .embeddingDim = static_cast(sizes[sizes.size() - 1]),
+ .tokenIds = std::move(tokenIds),
+ };
}
} // namespace rnexecutorch::models::embeddings
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h
index 93d0988c04..02cfefde4d 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h
+++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h
@@ -3,7 +3,8 @@
#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
#include
#include
-#include
+#include
+#include
namespace rnexecutorch {
namespace models::embeddings {
@@ -13,13 +14,16 @@ struct TokenIdsWithAttentionMask {
std::vector attentionMask;
};
-class TextEmbeddings final : public BaseEmbeddings {
+class TextEmbeddings final : public BaseModel {
public:
TextEmbeddings(const std::string &modelSource,
const std::string &tokenizerSource,
std::shared_ptr callInvoker);
- [[nodiscard(
- "Registered non-void function")]] std::shared_ptr
+ // Returns the raw [numTokens, embeddingDim] output. Pooled models give
+ // numTokens == 1; multi-vector (late-interaction) models give the full
+ // sequence. The TS layer reduces to a single vector or keeps the matrix
+ // based on the model's config.
+ [[nodiscard("Registered non-void function")]] EmbeddingResult
generate(const std::string input);
void unload() noexcept;
@@ -27,6 +31,8 @@ class TextEmbeddings final : public BaseEmbeddings {
mutable std::mutex inference_mutex_;
std::vector> inputShapes;
TokenIdsWithAttentionMask preprocess(const std::string &input);
+ static EmbeddingResult buildResult(const executorch::aten::Tensor &output,
+ std::vector tokenIds);
std::unique_ptr tokenizer;
};
} // namespace models::embeddings
diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.cpp b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.cpp
index 68a9a9fef4..3bf5fa2206 100644
--- a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.cpp
@@ -16,9 +16,10 @@ Encoder::Encoder(const std::string &tokenizerSource,
encoderSource, tokenizerSource, callInvoker)) {}
std::vector Encoder::generate(std::string input) {
- std::shared_ptr embeddingsText = encoder->generate(input);
+ std::shared_ptr embeddingsText =
+ encoder->generate(input).dataPtr;
std::shared_ptr embeddingsUncond =
- encoder->generate(std::string(constants::kBosToken));
+ encoder->generate(std::string(constants::kBosToken)).dataPtr;
assert(embeddingsText->size() == embeddingsUncond->size());
size_t embeddingsSize = embeddingsText->size() / sizeof(float);
diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt
index 5f9d7287a5..a901cd56fc 100644
--- a/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt
+++ b/packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt
@@ -218,7 +218,6 @@ add_rn_test(ObjectDetectionTests integration/ObjectDetectionTest.cpp
add_rn_test(ImageEmbeddingsTests integration/ImageEmbeddingsTest.cpp
SOURCES
${RNEXECUTORCH_DIR}/models/embeddings/image/ImageEmbeddings.cpp
- ${RNEXECUTORCH_DIR}/models/embeddings/BaseEmbeddings.cpp
${RNEXECUTORCH_DIR}/models/VisionModel.cpp
${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp
${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp
@@ -230,7 +229,6 @@ add_rn_test(ImageEmbeddingsTests integration/ImageEmbeddingsTest.cpp
add_rn_test(TextEmbeddingsTests integration/TextEmbeddingsTest.cpp
SOURCES
${RNEXECUTORCH_DIR}/models/embeddings/text/TextEmbeddings.cpp
- ${RNEXECUTORCH_DIR}/models/embeddings/BaseEmbeddings.cpp
${TOKENIZER_SOURCES}
LIBS tokenizers_deps
)
@@ -306,7 +304,6 @@ add_rn_test(TextToImageTests integration/TextToImageTest.cpp
${RNEXECUTORCH_DIR}/models/text_to_image/Decoder.cpp
${RNEXECUTORCH_DIR}/models/text_to_image/Scheduler.cpp
${RNEXECUTORCH_DIR}/models/embeddings/text/TextEmbeddings.cpp
- ${RNEXECUTORCH_DIR}/models/embeddings/BaseEmbeddings.cpp
${TOKENIZER_SOURCES}
LIBS tokenizers_deps
)
diff --git a/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextEmbeddingsTest.cpp b/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextEmbeddingsTest.cpp
index ff1abd4c30..0e0cc846b5 100644
--- a/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextEmbeddingsTest.cpp
+++ b/packages/react-native-executorch/common/rnexecutorch/tests/integration/TextEmbeddingsTest.cpp
@@ -53,23 +53,23 @@ TEST(TextEmbeddingsGenerateTests, EmptyStringReturnsResults) {
TextEmbeddings model(kValidTextEmbeddingsModelPath,
kValidTextEmbeddingsTokenizerPath, nullptr);
auto result = model.generate("");
- EXPECT_NE(result, nullptr);
- EXPECT_GT(result->size(), 0u);
+ EXPECT_NE(result.dataPtr, nullptr);
+ EXPECT_GT(result.dataPtr->size(), 0u);
}
TEST(TextEmbeddingsGenerateTests, ValidTextReturnsResults) {
TextEmbeddings model(kValidTextEmbeddingsModelPath,
kValidTextEmbeddingsTokenizerPath, nullptr);
auto result = model.generate("Hello, world!");
- EXPECT_NE(result, nullptr);
- EXPECT_GT(result->size(), 0u);
+ EXPECT_NE(result.dataPtr, nullptr);
+ EXPECT_GT(result.dataPtr->size(), 0u);
}
TEST(TextEmbeddingsGenerateTests, ResultsHaveCorrectSize) {
TextEmbeddings model(kValidTextEmbeddingsModelPath,
kValidTextEmbeddingsTokenizerPath, nullptr);
auto result = model.generate("This is a test sentence.");
- size_t numFloats = result->size() / sizeof(float);
+ size_t numFloats = result.dataPtr->size() / sizeof(float);
EXPECT_EQ(numFloats, kMiniLmEmbeddingDimensions);
}
@@ -78,8 +78,8 @@ TEST(TextEmbeddingsGenerateTests, ResultsAreNormalized) {
kValidTextEmbeddingsTokenizerPath, nullptr);
auto result = model.generate("The quick brown fox jumps over the lazy dog.");
- const float *data = reinterpret_cast(result->data());
- size_t numFloats = result->size() / sizeof(float);
+ const float *data = reinterpret_cast(result.dataPtr->data());
+ size_t numFloats = result.dataPtr->size() / sizeof(float);
float sumOfSquares = 0.0f;
for (size_t i = 0; i < numFloats; ++i) {
@@ -94,8 +94,8 @@ TEST(TextEmbeddingsGenerateTests, ResultsContainValidValues) {
kValidTextEmbeddingsTokenizerPath, nullptr);
auto result = model.generate("Testing valid values.");
- const float *data = reinterpret_cast(result->data());
- size_t numFloats = result->size() / sizeof(float);
+ const float *data = reinterpret_cast(result.dataPtr->data());
+ size_t numFloats = result.dataPtr->size() / sizeof(float);
for (size_t i = 0; i < numFloats; ++i) {
EXPECT_FALSE(std::isnan(data[i]));
@@ -110,9 +110,9 @@ TEST(TextEmbeddingsGenerateTests, DifferentTextProducesDifferentEmbeddings) {
auto result1 = model.generate("Hello, world!");
auto result2 = model.generate("Goodbye, moon!");
- const float *data1 = reinterpret_cast(result1->data());
- const float *data2 = reinterpret_cast(result2->data());
- size_t numFloats = result1->size() / sizeof(float);
+ const float *data1 = reinterpret_cast(result1.dataPtr->data());
+ const float *data2 = reinterpret_cast(result2.dataPtr->data());
+ size_t numFloats = result1.dataPtr->size() / sizeof(float);
bool allEqual = true;
for (size_t i = 0; i < numFloats; ++i) {
@@ -131,9 +131,9 @@ TEST(TextEmbeddingsGenerateTests, SimilarTextProducesSimilarEmbeddings) {
auto result1 = model.generate("I love programming");
auto result2 = model.generate("I enjoy coding");
- const float *data1 = reinterpret_cast(result1->data());
- const float *data2 = reinterpret_cast(result2->data());
- size_t numFloats = result1->size() / sizeof(float);
+ const float *data1 = reinterpret_cast(result1.dataPtr->data());
+ const float *data2 = reinterpret_cast(result2.dataPtr->data());
+ size_t numFloats = result1.dataPtr->size() / sizeof(float);
float dotProduct = 0.0f;
for (size_t i = 0; i < numFloats; ++i) {
diff --git a/packages/react-native-executorch/src/constants/modelRegistry.ts b/packages/react-native-executorch/src/constants/modelRegistry.ts
index eb0c98dae7..4c36c6a1fa 100644
--- a/packages/react-native-executorch/src/constants/modelRegistry.ts
+++ b/packages/react-native-executorch/src/constants/modelRegistry.ts
@@ -260,6 +260,59 @@ const GEMMA4_E2B_MM_VARIANTS = {
},
};
+// Asymmetric query/document prompts the LFM models are trained with.
+// forward(text, role) auto-prepends these.
+const LFM_EMBEDDING_PROMPTS = { query: 'query: ', document: 'document: ' };
+const LFM_COLBERT_PROMPTS = { query: '[Q] ', document: '[D] ' };
+
+const LFM2_5_EMBEDDING_350M_VARIANTS = {
+ mlx: {
+ base: {
+ modelName: 'lfm2-5-embedding-350m' as const,
+ modelSource: M.LFM2_5_EMBEDDING_350M_MLX_MODEL,
+ tokenizerSource: M.LFM2_5_EMBEDDING_350M_TOKENIZER,
+ prompts: LFM_EMBEDDING_PROMPTS,
+ },
+ },
+ xnnpack: {
+ base: {
+ modelName: 'lfm2-5-embedding-350m' as const,
+ modelSource: M.LFM2_5_EMBEDDING_350M_XNNPACK_MODEL,
+ tokenizerSource: M.LFM2_5_EMBEDDING_350M_TOKENIZER,
+ prompts: LFM_EMBEDDING_PROMPTS,
+ },
+ },
+};
+
+const LFM_COLBERT_SKIP_LIST = [
+ 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524,
+ 535, 536, 537, 538, 539, 540, 541, 568, 569, 570, 571, 572, 573, 600, 601,
+ 602, 603,
+];
+
+const LFM2_5_COLBERT_350M_VARIANTS = {
+ mlx: {
+ base: {
+ modelName: 'lfm2-5-colbert-350m' as const,
+ modelSource: M.LFM2_5_COLBERT_350M_MLX_MODEL,
+ tokenizerSource: M.LFM2_5_COLBERT_350M_TOKENIZER,
+ prompts: LFM_COLBERT_PROMPTS,
+ multiVector: true as const,
+ skipListIds: LFM_COLBERT_SKIP_LIST,
+ },
+ },
+ xnnpack: {
+ base: {
+ modelName: 'lfm2-5-colbert-350m' as const,
+ modelSource: M.LFM2_5_COLBERT_350M_XNNPACK_MODEL,
+ tokenizerSource: M.LFM2_5_COLBERT_350M_TOKENIZER,
+ prompts: LFM_COLBERT_PROMPTS,
+ multiVector: true as const,
+ skipListIds: LFM_COLBERT_SKIP_LIST,
+ },
+ },
+};
+
const EFFICIENTNET_V2_S_VARIANTS = {
xnnpack: {
base: {
@@ -742,6 +795,14 @@ export const models = {
M.PARAPHRASE_MULTILINGUAL_MINILM_L12_V2_QUANTIZED
),
clip_vit_base_patch32_text: base(M.CLIP_VIT_BASE_PATCH32_TEXT),
+ lfm2_5_embedding_350m: variant(LFM2_5_EMBEDDING_350M_VARIANTS, {
+ ios: 'mlx',
+ android: 'xnnpack',
+ }),
+ lfm2_5_colbert_350m: variant(LFM2_5_COLBERT_350M_VARIANTS, {
+ ios: 'mlx',
+ android: 'xnnpack',
+ }),
},
image_embedding: {
clip_vit_base_patch32_image: pair(
diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts
index 0e36f812ff..bd6cddf4a3 100644
--- a/packages/react-native-executorch/src/constants/modelUrls.ts
+++ b/packages/react-native-executorch/src/constants/modelUrls.ts
@@ -1195,6 +1195,12 @@ export const DISTILUSE_BASE_MULTILINGUAL_CASED_V2_8DA4W_MODEL = `${URL_PREFIX}-d
export const DISTILUSE_BASE_MULTILINGUAL_CASED_V2_TOKENIZER = `${URL_PREFIX}-distiluse-base-multilingual-cased-v2/${PREVIOUS_VERSION_TAG}/tokenizer.json`;
const PARAPHRASE_MULTILINGUAL_MINILM_L12_V2_QUANTIZED_MODEL = `${URL_PREFIX}-paraphrase-multilingual-MiniLM-L12-v2/${PREVIOUS_VERSION_TAG}/xnnpack/paraphrase_multilingual_minilm_l12_v2_xnnpack_8da4w.pte`;
const PARAPHRASE_MULTILINGUAL_MINILM_L12_V2_TOKENIZER = `${URL_PREFIX}-paraphrase-multilingual-MiniLM-L12-v2/${PREVIOUS_VERSION_TAG}/tokenizer.json`;
+export const LFM2_5_EMBEDDING_350M_XNNPACK_MODEL = `${URL_PREFIX}-lfm2.5-embedding-350m/${PREVIOUS_VERSION_TAG}/xnnpack/lfm_2_5_embedding_350m_xnnpack_8da4w.pte`;
+export const LFM2_5_EMBEDDING_350M_MLX_MODEL = `${URL_PREFIX}-lfm2.5-embedding-350m/${PREVIOUS_VERSION_TAG}/mlx/lfm_2_5_embedding_350m_mlx_int4.pte`;
+export const LFM2_5_EMBEDDING_350M_TOKENIZER = `${URL_PREFIX}-lfm2.5-embedding-350m/${PREVIOUS_VERSION_TAG}/tokenizer.json`;
+export const LFM2_5_COLBERT_350M_XNNPACK_MODEL = `${URL_PREFIX}-lfm2.5-colbert-350m/${PREVIOUS_VERSION_TAG}/xnnpack/lfm_2_5_colbert_350m_xnnpack_8da4w.pte`;
+export const LFM2_5_COLBERT_350M_MLX_MODEL = `${URL_PREFIX}-lfm2.5-colbert-350m/${PREVIOUS_VERSION_TAG}/mlx/lfm_2_5_colbert_350m_mlx_int4.pte`;
+export const LFM2_5_COLBERT_350M_TOKENIZER = `${URL_PREFIX}-lfm2.5-colbert-350m/${PREVIOUS_VERSION_TAG}/tokenizer.json`;
const CLIP_VIT_BASE_PATCH32_TEXT_MODEL = `${URL_PREFIX}-clip-vit-base-patch32/${PREVIOUS_VERSION_TAG}/xnnpack/clip_vit_base_patch32_text_xnnpack_fp32.pte`;
const CLIP_VIT_BASE_PATCH32_TEXT_TOKENIZER = `${URL_PREFIX}-clip-vit-base-patch32/${PREVIOUS_VERSION_TAG}/tokenizer.json`;
diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts
index 31ee179925..2f100b8cbb 100644
--- a/packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts
+++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts
@@ -1,20 +1,25 @@
import { TextEmbeddingsModule } from '../../modules/natural_language_processing/TextEmbeddingsModule';
import { useModuleFactory } from '../useModuleFactory';
import {
+ EmbeddingRole,
+ ForwardFn,
+ TextEmbeddingsModel,
TextEmbeddingsType,
TextEmbeddingsProps,
} from '../../types/textEmbeddings';
/**
- * React hook for managing a Text Embeddings model instance.
+ * React hook for a Text Embeddings model.
* @category Hooks
- * @param TextEmbeddingsProps - Configuration object containing `model` source and optional `preventLoad` flag.
- * @returns Ready to use Text Embeddings model.
+ * @param TextEmbeddingsProps - `model` source + optional `preventLoad`.
+ * @returns Ready to use embeddings model. `forward` returns the raw
+ * [numTokens, embeddingDim] result; use `toVector` for a single vector.
+ * Models with prompts require a `role` ('query' | 'document') on `forward`.
*/
-export const useTextEmbeddings = ({
+export const useTextEmbeddings = ({
model,
preventLoad = false,
-}: TextEmbeddingsProps): TextEmbeddingsType => {
+}: TextEmbeddingsProps): TextEmbeddingsType => {
const { error, isReady, isGenerating, downloadProgress, runForward } =
useModuleFactory({
factory: (config, onProgress) =>
@@ -24,7 +29,8 @@ export const useTextEmbeddings = ({
preventLoad,
});
- const forward = (input: string) => runForward((inst) => inst.forward(input));
+ const forward = ((input: string, role?: EmbeddingRole) =>
+ runForward((inst) => inst.forward(input, role))) as ForwardFn;
return { error, isReady, isGenerating, downloadProgress, forward };
};
diff --git a/packages/react-native-executorch/src/index.ts b/packages/react-native-executorch/src/index.ts
index 1f190d41f5..34cdf97d8d 100644
--- a/packages/react-native-executorch/src/index.ts
+++ b/packages/react-native-executorch/src/index.ts
@@ -212,6 +212,7 @@ export * from './utils/ResourceFetcher';
export * from './utils/ResourceFetcherUtils';
export * from './utils/BaseResourceFetcherClass';
export * from './utils/llm';
+export * from './utils/textEmbeddings';
export * from './common/Logger';
export * from './utils/llms/context_strategy';
export * from './utils/segmentAnythingPrompts';
diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts
index 27b0e59ceb..abb620e981 100644
--- a/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts
+++ b/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts
@@ -1,5 +1,11 @@
import { ResourceSource } from '../../types/common';
-import { TextEmbeddingsModelName } from '../../types/textEmbeddings';
+import {
+ EmbeddingPrompts,
+ EmbeddingResult,
+ EmbeddingRole,
+ TextEmbeddingsModel,
+ TextEmbeddingsModelName,
+} from '../../types/textEmbeddings';
import { ResourceFetcher } from '../../utils/ResourceFetcher';
import { BaseModule } from '../BaseModule';
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
@@ -7,27 +13,35 @@ import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils';
import { Logger } from '../../common/Logger';
/**
- * Module for generating text embeddings from input text.
+ * Module for text embeddings. `forward` returns a single pooled `Float32Array`
+ * for standard models, or the per-token `EmbeddingResult` for `multiVector`
+ * (late-interaction) models. The native runner always produces the raw
+ * [numTokens, embeddingDim] matrix; the reduction to a single vector happens
+ * here so the common single-vector API stays `Float32Array`.
* @category Typescript API
*/
export class TextEmbeddingsModule extends BaseModule {
- private constructor(nativeModule: unknown) {
+ private prompts?: EmbeddingPrompts;
+ private multiVector: boolean;
+
+ private constructor(
+ nativeModule: unknown,
+ prompts: EmbeddingPrompts | undefined,
+ multiVector: boolean
+ ) {
super();
this.nativeModule = nativeModule;
+ this.prompts = prompts;
+ this.multiVector = multiVector;
}
/**
* Creates a text embeddings instance for a built-in model.
- * @param namedSources - An object specifying which built-in model to load and where to fetch it from.
- * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
- * @returns A Promise resolving to a `TextEmbeddingsModule` instance.
+ * @param namedSources - The model config (+ optional prompts / multiVector).
+ * @param onDownloadProgress - Optional download progress callback (0..1).
*/
static async fromModelName(
- namedSources: {
- modelName: TextEmbeddingsModelName;
- modelSource: ResourceSource;
- tokenizerSource: ResourceSource;
- },
+ namedSources: TextEmbeddingsModel,
onDownloadProgress: (progress: number) => void = () => {}
): Promise {
try {
@@ -41,7 +55,9 @@ export class TextEmbeddingsModule extends BaseModule {
throw new RnExecutorchError(RnExecutorchErrorCode.DownloadInterrupted);
}
return new TextEmbeddingsModule(
- await global.loadTextEmbeddings(modelPath, tokenizerPath)
+ await global.loadTextEmbeddings(modelPath, tokenizerPath),
+ namedSources.prompts,
+ namedSources.multiVector ?? false
);
} catch (error) {
Logger.error('Load failed:', error);
@@ -50,14 +66,9 @@ export class TextEmbeddingsModule extends BaseModule {
}
/**
- * Creates a text embeddings instance with a user-provided model binary and tokenizer.
- * Use this when working with a custom-exported model that is not one of the built-in presets.
- * @remarks The native model contract for this method is not formally defined and may change
- * between releases. Refer to the native source code for the current expected tensor interface.
- * @param modelSource - A fetchable resource pointing to the model binary.
- * @param tokenizerSource - A fetchable resource pointing to the tokenizer file.
- * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
- * @returns A Promise resolving to a `TextEmbeddingsModule` instance.
+ * Creates a text embeddings instance from a custom model binary + tokenizer.
+ * @remarks The native tensor contract is not formally guaranteed across
+ * releases.
*/
static fromCustomModel(
modelSource: ResourceSource,
@@ -75,13 +86,32 @@ export class TextEmbeddingsModule extends BaseModule {
}
/**
- * Executes the model's forward pass to generate an embedding for the provided text.
- * @param input - The text string to embed.
- * @returns A Promise resolving to a `Float32Array` containing the embedding vector.
+ * Embed text. Standard models return the single pooled `Float32Array`;
+ * `multiVector` models return the per-token `EmbeddingResult`.
+ * @param input - The text to embed.
+ * @param role - 'query' | 'document'; prepends the model's prompt for that
+ * role when configured (no-op otherwise).
*/
- async forward(input: string): Promise {
+ async forward(
+ input: string,
+ role?: EmbeddingRole
+ ): Promise {
if (this.nativeModule == null)
throw new RnExecutorchError(RnExecutorchErrorCode.ModuleNotLoaded);
- return new Float32Array(await this.nativeModule.generate(input));
+ const prefix = (role && this.prompts?.[role]) || '';
+ const res = await this.nativeModule.generate(prefix + input);
+ // res.dataPtr is already a Float32Array view over the owned native buffer
+ // (built at the JSI boundary).
+ const vectors = res.dataPtr as Float32Array;
+ if (!this.multiVector) {
+ // Pooled models output [1, embeddingDim]; return that single row.
+ return vectors.subarray(0, res.embeddingDim);
+ }
+ return {
+ vectors,
+ numTokens: res.numTokens,
+ embeddingDim: res.embeddingDim,
+ tokenIds: res.tokenIds,
+ };
}
}
diff --git a/packages/react-native-executorch/src/types/textEmbeddings.ts b/packages/react-native-executorch/src/types/textEmbeddings.ts
index d9cd120e26..1b056a1f7b 100644
--- a/packages/react-native-executorch/src/types/textEmbeddings.ts
+++ b/packages/react-native-executorch/src/types/textEmbeddings.ts
@@ -12,65 +12,124 @@ export type TextEmbeddingsModelName =
| 'multi-qa-mpnet-base-dot-v1'
| 'distiluse-base-multilingual-cased-v2-8da4w'
| 'paraphrase-multilingual-minilm-l12-v2-quantized'
- | 'clip-vit-base-patch32-text';
+ | 'clip-vit-base-patch32-text'
+ | 'lfm2-5-embedding-350m'
+ | 'lfm2-5-colbert-350m';
/**
- * Props for the useTextEmbeddings hook.
+ * Per-token (multi-vector) embedding output for late-interaction models (e.g.
+ * ColBERT): a [numTokens, embeddingDim] fp32 matrix (row-major) plus the input
+ * token ids. Standard models return a single pooled `Float32Array` from
+ * `forward` instead; only `multiVector` models yield this.
* @category Types
- * @property {object} model - An object containing the model configuration.
- * @property {TextEmbeddingsModelName} model.modelName - Unique name identifying the model.
- * @property {ResourceSource} model.modelSource - The source of the text embeddings model binary.
- * @property {ResourceSource} model.tokenizerSource - The source of the tokenizer JSON file.
- * @property {boolean} [preventLoad] - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook.
*/
-export interface TextEmbeddingsProps {
- model: {
- /**
- * The unique name of the text embeddings model.
- */
- modelName: TextEmbeddingsModelName;
- /**
- * The source of the text embeddings model binary.
- */
- modelSource: ResourceSource;
- /**
- * The source of the tokenizer JSON file.
- */
- tokenizerSource: ResourceSource;
- };
- preventLoad?: boolean;
+export interface EmbeddingResult {
+ /** Flat [numTokens * embeddingDim] fp32 vectors (row-major). */
+ vectors: Float32Array;
+ /** Number of token rows. */
+ numTokens: number;
+ /** Per-token vector dimension. */
+ embeddingDim: number;
+ /** Input token ids per row. */
+ tokenIds: number[];
}
/**
- * React hook state and methods for managing a Text Embeddings model instance.
+ * Role for `forward`. Some models are trained with asymmetric query/document
+ * prompts (e.g. LFM2.5 uses `query: `/`document: `, ColBERT uses `[Q] `/`[D] `).
+ * Passing a role auto-prepends the model's configured prompt for that role.
* @category Types
*/
-export interface TextEmbeddingsType {
- /**
- * Contains the error message if the model failed to load or during inference.
- */
- error: null | RnExecutorchError;
+export type EmbeddingRole = 'query' | 'document';
- /**
- * Indicates whether the embeddings model has successfully loaded and is ready for inference.
- */
- isReady: boolean;
+/**
+ * Asymmetric prompts a model is trained with. When a model config carries
+ * these, `forward` REQUIRES a `role` so the matching prompt is always applied
+ * (forgetting it would silently embed raw text and wreck asymmetric retrieval).
+ * @category Types
+ */
+export interface EmbeddingPrompts {
+ query: string;
+ document: string;
+}
+/**
+ * A text embeddings model config. Two optional flags drive `forward`:
+ * - `prompts` present -> `forward` REQUIRES a `role` (auto-prepends the prompt)
+ * - `multiVector` true -> `forward` returns the per-token `EmbeddingResult`;
+ * otherwise it returns a single pooled `Float32Array`.
+ * @category Types
+ */
+export interface TextEmbeddingsModel {
+ modelName: TextEmbeddingsModelName;
+ modelSource: ResourceSource;
+ tokenizerSource: ResourceSource;
+ prompts?: EmbeddingPrompts;
+ multiVector?: boolean;
/**
- * Indicates whether the model is currently generating embeddings.
+ * Document token ids to exclude from late-interaction scoring (e.g. ColBERT's
+ * punctuation skipList). Derived from the model's training config, so it's
+ * shipped here rather than reconstructed by the consumer, who passes it to
+ * their own MaxSim scoring.
*/
- isGenerating: boolean;
+ skipListIds?: number[];
+}
- /**
- * Tracks the progress of the model download process (value between 0 and 1).
- */
+/**
+ * `forward`'s signature, computed from the model config:
+ * - return type: `EmbeddingResult` if `multiVector`, else `Float32Array`.
+ * - role arg: required if the model has `prompts`, else absent.
+ */
+export type ForwardReturn = M extends {
+ multiVector: true;
+}
+ ? EmbeddingResult
+ : Float32Array;
+
+/**
+ * `forward`'s signature, computed from the model config:
+ * - A model that DEFINITELY has prompts -> `role` is REQUIRED.
+ * - A model that definitely has NO prompts (`prompts?: undefined`) -> no role.
+ * - Otherwise (prompts optional / unknown, e.g. a heterogeneous model list) ->
+ * `role` is OPTIONAL.
+ */
+export type ForwardFn = M extends {
+ prompts: EmbeddingPrompts;
+}
+ ? (input: string, role: EmbeddingRole) => Promise>
+ : undefined extends M['prompts']
+ ? M['prompts'] extends undefined
+ ? (input: string) => Promise>
+ : (input: string, role?: EmbeddingRole) => Promise>
+ : (input: string) => Promise>;
+
+/**
+ * Props for the useTextEmbeddings hook.
+ * @category Types
+ */
+export interface TextEmbeddingsProps<
+ M extends TextEmbeddingsModel = TextEmbeddingsModel,
+> {
+ model: M;
+ preventLoad?: boolean;
+}
+
+/**
+ * React hook state and methods for a Text Embeddings model instance.
+ * @category Types
+ */
+export interface TextEmbeddingsType<
+ M extends TextEmbeddingsModel = TextEmbeddingsModel,
+> {
+ error: null | RnExecutorchError;
+ isReady: boolean;
+ isGenerating: boolean;
downloadProgress: number;
/**
- * Runs the text embeddings model on the provided input string.
- * @param input - The text string to embed.
- * @returns A promise resolving to a Float32Array containing the vector embeddings.
- * @throws {RnExecutorchError} If the model is not loaded or is currently processing another request.
+ * Embed text. Standard models return a single pooled `Float32Array`;
+ * `multiVector` models return the per-token `EmbeddingResult`. Models with
+ * `prompts` require a `role` ('query' | 'document').
*/
- forward(input: string): Promise;
+ forward: ForwardFn;
}
diff --git a/packages/react-native-executorch/src/utils/textEmbeddings.ts b/packages/react-native-executorch/src/utils/textEmbeddings.ts
new file mode 100644
index 0000000000..e9be7cf774
--- /dev/null
+++ b/packages/react-native-executorch/src/utils/textEmbeddings.ts
@@ -0,0 +1,37 @@
+import { EmbeddingResult } from '../types/textEmbeddings';
+
+/**
+ * Get the single pooled embedding vector from a result. Convenience for the
+ * common single-vector case: the exported graph pools + L2-normalizes to a
+ * [1, embeddingDim] output, so this returns row 0.
+ *
+ * For multi-vector (late-interaction) models, prefer the full per-token
+ * vectors (`getTokenVectors`); row 0 alone is not a meaningful sentence
+ * embedding there.
+ *
+ * @category Utils
+ */
+export function toVector(result: EmbeddingResult): Float32Array {
+ return result.vectors.slice(0, result.embeddingDim);
+}
+
+/**
+ * Split a result's flat `vectors` buffer into per-token rows
+ * (`numTokens` arrays of length `embeddingDim`). Useful for inspecting or
+ * storing individual token vectors (e.g. a multi-vector vector DB).
+ *
+ * The rows are zero-copy `subarray` VIEWS over `result.vectors` — valid only
+ * while that buffer is alive and not mutated. Copy them (e.g. `new
+ * Float32Array(row)`) before storing beyond the result's lifetime. (`toVector`
+ * by contrast returns an independent copy.)
+ *
+ * @category Utils
+ */
+export function getTokenVectors(result: EmbeddingResult): Float32Array[] {
+ const { vectors, numTokens, embeddingDim } = result;
+ const rows: Float32Array[] = [];
+ for (let i = 0; i < numTokens; i++) {
+ rows.push(vectors.subarray(i * embeddingDim, (i + 1) * embeddingDim));
+ }
+ return rows;
+}