From 2918b6b9973b0927121fa2dfeb11ae22bd502710 Mon Sep 17 00:00:00 2001
From: Sergei Zharinov
Date: Sat, 10 Jan 2026 13:51:23 -0300
Subject: [PATCH 1/7] feat: Add `from_bytes()` and `from_raw_parts()` to load model from raw data

---
 src/model.rs | 115 ++++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 86 insertions(+), 29 deletions(-)

diff --git a/src/model.rs b/src/model.rs
index d695bb9..13b211b 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -33,12 +33,10 @@ impl StaticModel {
         normalize: Option<bool>,
         subfolder: Option<&str>,
     ) -> Result<Self> {
-        // If provided, set HF token for authenticated downloads
         if let Some(tok) = token {
             env::set_var("HF_HUB_TOKEN", tok);
         }

-        // Locate tokenizer.json, model.safetensors, config.json
         let (tok_path, mdl_path, cfg_path) = {
             let base = repo_or_path.as_ref();
             if base.exists() {
@@ -61,38 +59,38 @@ impl StaticModel {
             }
         };

-        // Load the tokenizer
-        let tokenizer = Tokenizer::from_file(&tok_path).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;
+        let tokenizer_bytes = fs::read(&tok_path).context("failed to read tokenizer.json")?;
+        let safetensors_bytes = fs::read(&mdl_path).context("failed to read model.safetensors")?;
+        let config_bytes = fs::read(&cfg_path).context("failed to read config.json")?;

-        // Median-token-length hack for pre-truncation
-        let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
-        lens.sort_unstable();
-        let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);
+        Self::from_bytes(&tokenizer_bytes, &safetensors_bytes, &config_bytes, normalize)
+    }
+
+    /// Load a Model2Vec model from raw bytes.
+    ///
+    /// # Arguments
+    /// * `tokenizer_bytes` - Contents of tokenizer.json
+    /// * `safetensors_bytes` - Contents of model.safetensors
+    /// * `config_bytes` - Contents of config.json
+    /// * `normalize` - Optional flag to override normalization (default from config)
+    pub fn from_bytes(
+        tokenizer_bytes: &[u8],
+        safetensors_bytes: &[u8],
+        config_bytes: &[u8],
+        normalize: Option<bool>,
+    ) -> Result<Self> {
+        let tokenizer =
+            Tokenizer::from_bytes(tokenizer_bytes).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;
+
+        let median_token_length = Self::median_token_length(&tokenizer);
+        let unk_token_id = Self::unk_token_id(&tokenizer)?;

-        // Read normalize default from config.json
-        let cfg_file = std::fs::File::open(&cfg_path).context("failed to read config.json")?;
-        let cfg: Value = serde_json::from_reader(&cfg_file).context("failed to parse config.json")?;
+        let cfg: Value = serde_json::from_slice(config_bytes).context("failed to parse config.json")?;
         let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
         let normalize = normalize.unwrap_or(cfg_norm);

-        // Serialize the tokenizer to JSON, then parse it and get the unk_token
-        let spec_json = tokenizer
-            .to_string(false)
-            .map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
-        let spec: Value = serde_json::from_str(&spec_json)?;
-        let unk_token = spec
-            .get("model")
-            .and_then(|m| m.get("unk_token"))
-            .and_then(Value::as_str)
-            .unwrap_or("[UNK]");
-        let unk_token_id = tokenizer
-            .token_to_id(unk_token)
-            .ok_or_else(|| anyhow!("tokenizer claims unk_token='{unk_token}' but it isn't in the vocab"))?
-            as usize;

-        // Load the safetensors
-        let model_bytes = fs::read(&mdl_path).context("failed to read model.safetensors")?;
-        let safet = SafeTensors::deserialize(&model_bytes).context("failed to parse safetensors")?;
+        let safet = SafeTensors::deserialize(safetensors_bytes).context("failed to parse safetensors")?;
         let tensor = safet
             .tensor("embeddings")
             .or_else(|_| safet.tensor("0"))
@@ -161,10 +159,69 @@ impl StaticModel {
             token_mapping,
             normalize,
             median_token_length,
-            unk_token_id: Some(unk_token_id),
+            unk_token_id,
         })
     }

+    /// Construct from pre-parsed parts.
+    ///
+    /// # Arguments
+    /// * `tokenizer` - Pre-deserialized tokenizer
+    /// * `embeddings` - Raw f32 embedding data
+    /// * `rows` - Number of vocabulary entries
+    /// * `cols` - Embedding dimension
+    /// * `normalize` - Whether to L2-normalize output embeddings
+    pub fn from_raw_parts(
+        tokenizer: Tokenizer,
+        embeddings: &[f32],
+        rows: usize,
+        cols: usize,
+        normalize: bool,
+    ) -> Result<Self> {
+        if embeddings.len() != rows * cols {
+            return Err(anyhow!(
+                "embeddings length {} != rows {} * cols {}",
+                embeddings.len(),
+                rows,
+                cols
+            ));
+        }
+
+        let median_token_length = Self::median_token_length(&tokenizer);
+        let unk_token_id = Self::unk_token_id(&tokenizer)?;
+
+        let embeddings = Array2::from_shape_vec((rows, cols), embeddings.to_vec())
+            .context("failed to build embeddings array")?;
+
+        Ok(Self {
+            tokenizer,
+            embeddings,
+            weights: None,
+            token_mapping: None,
+            normalize,
+            median_token_length,
+            unk_token_id,
+        })
+    }
+
+    fn median_token_length(tokenizer: &Tokenizer) -> usize {
+        let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
+        lens.sort_unstable();
+        lens.get(lens.len() / 2).copied().unwrap_or(1)
+    }
+
+    fn unk_token_id(tokenizer: &Tokenizer) -> Result<Option<usize>> {
+        let spec_json = tokenizer
+            .to_string(false)
+            .map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
+        let spec: Value = serde_json::from_str(&spec_json)?;
+        let unk_token = spec
+            .get("model")
+            .and_then(|m| m.get("unk_token"))
+            .and_then(Value::as_str);
+        Ok(unk_token.and_then(|tok| tokenizer.token_to_id(tok)).map(|id| id as usize))
+    }
+
     /// Char-level truncation to max_tokens * median_token_length
     fn truncate_str(s: &str, max_tokens: usize, median_len: usize) -> &str {
         let max_chars = max_tokens.saturating_mul(median_len);

From a08059eb849496050b0103655f0e92c5ce19815a Mon Sep 17 00:00:00 2001
From: Sergei Zharinov
Date: Sat, 10 Jan 2026 14:36:28 -0300
Subject: [PATCH 2/7] chore: Add test for `from_raw_parts`

---
 src/model.rs        |  1 +
 tests/test_model.rs | 23 +++++++++++++++++++++++
 2 files changed, 24 insertions(+)

diff --git a/src/model.rs b/src/model.rs
index 13b211b..ced76a2 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -171,6 +171,7 @@ impl StaticModel {
     /// * `rows` - Number of vocabulary entries
     /// * `cols` - Embedding dimension
     /// * `normalize` - Whether to L2-normalize output embeddings
+    #[allow(dead_code)]
     pub fn from_raw_parts(
         tokenizer: Tokenizer,
         embeddings: &[f32],
diff --git a/tests/test_model.rs b/tests/test_model.rs
index f09b8c2..3fbb075 100644
--- a/tests/test_model.rs
+++ b/tests/test_model.rs
@@ -70,3 +70,26 @@ fn test_normalization_flag_override() {
         "Without normalization override, norm should be larger"
     );
 }
+
+/// Test from_raw_parts constructor
+#[test]
+fn test_from_raw_parts() {
+    use std::fs;
+    use tokenizers::Tokenizer;
+    use safetensors::SafeTensors;
+
+    let path = "tests/fixtures/test-model-float32";
+    let tokenizer = Tokenizer::from_file(format!("{path}/tokenizer.json")).unwrap();
+    let bytes = fs::read(format!("{path}/model.safetensors")).unwrap();
+    let tensors = SafeTensors::deserialize(&bytes).unwrap();
+    let tensor = tensors.tensor("embeddings").unwrap();
+    let [rows, cols]: [usize; 2] = tensor.shape().try_into().unwrap();
+    let floats: Vec<f32> = tensor.data()
+        .chunks_exact(4)
+        .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
+        .collect();
+
+    let model = StaticModel::from_raw_parts(tokenizer, &floats, rows, cols, true).unwrap();
+    let emb = model.encode_single("hello");
+    assert!(!emb.is_empty());
+}

From e992ae04d7769d8deb05f454b86a277f09193030 Mon Sep 17 00:00:00 2001
From: Sergei Zharinov
Date: Sat, 10 Jan 2026 14:49:44 -0300
Subject: [PATCH 3/7] feat: Add `from_raw_parts()` constructor

- `from_pretrained` now delegates to `from_raw_parts`
- Fixes BPE tokenizer support (unk_token_id now optional)
---
 src/model.rs        | 91 +++++++++++++++------------------------------
 tests/test_model.rs |  2 +-
 2 files changed, 32 insertions(+), 61 deletions(-)

diff --git a/src/model.rs b/src/model.rs
index ced76a2..dff1a8a 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -33,10 +33,12 @@ impl StaticModel {
         normalize: Option<bool>,
         subfolder: Option<&str>,
     ) -> Result<Self> {
+        // If provided, set HF token for authenticated downloads
         if let Some(tok) = token {
             env::set_var("HF_HUB_TOKEN", tok);
         }

+        // Locate tokenizer.json, model.safetensors, config.json
         let (tok_path, mdl_path, cfg_path) = {
             let base = repo_or_path.as_ref();
             if base.exists() {
@@ -59,38 +61,18 @@ impl StaticModel {
             }
         };

-        let tokenizer_bytes = fs::read(&tok_path).context("failed to read tokenizer.json")?;
-        let safetensors_bytes = fs::read(&mdl_path).context("failed to read model.safetensors")?;
-        let config_bytes = fs::read(&cfg_path).context("failed to read config.json")?;
+        // Load the tokenizer
+        let tokenizer = Tokenizer::from_file(&tok_path).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;

-        Self::from_bytes(&tokenizer_bytes, &safetensors_bytes, &config_bytes, normalize)
-    }
-
-    /// Load a Model2Vec model from raw bytes.
-    ///
-    /// # Arguments
-    /// * `tokenizer_bytes` - Contents of tokenizer.json
-    /// * `safetensors_bytes` - Contents of model.safetensors
-    /// * `config_bytes` - Contents of config.json
-    /// * `normalize` - Optional flag to override normalization (default from config)
-    pub fn from_bytes(
-        tokenizer_bytes: &[u8],
-        safetensors_bytes: &[u8],
-        config_bytes: &[u8],
-        normalize: Option<bool>,
-    ) -> Result<Self> {
-        let tokenizer =
-            Tokenizer::from_bytes(tokenizer_bytes).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;
-
-        let median_token_length = Self::median_token_length(&tokenizer);
-        let unk_token_id = Self::unk_token_id(&tokenizer)?;
-
-        let cfg: Value = serde_json::from_slice(config_bytes).context("failed to parse config.json")?;
+        // Read normalize default from config.json
+        let cfg_file = std::fs::File::open(&cfg_path).context("failed to read config.json")?;
+        let cfg: Value = serde_json::from_reader(&cfg_file).context("failed to parse config.json")?;
         let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
         let normalize = normalize.unwrap_or(cfg_norm);

-        let safet = SafeTensors::deserialize(safetensors_bytes).context("failed to parse safetensors")?;
+        // Load the safetensors
+        let model_bytes = fs::read(&mdl_path).context("failed to read model.safetensors")?;
+        let safet = SafeTensors::deserialize(&model_bytes).context("failed to parse safetensors")?;
         let tensor = safet
             .tensor("embeddings")
             .or_else(|_| safet.tensor("0"))
@@ -113,7 +95,6 @@ impl StaticModel {
             Dtype::I8 => raw.iter().map(|&b| f32::from(b as i8)).collect(),
             other => return Err(anyhow!("unsupported tensor dtype: {other:?}")),
         };
-        let embeddings = Array2::from_shape_vec((rows, cols), floats).context("failed to build embeddings array")?;

         // Load optional weights for vocabulary quantization
         let weights = match safet.tensor("weights") {
@@ -152,15 +133,7 @@ impl StaticModel {
             Err(_) => None,
         };

-        Ok(Self {
-            tokenizer,
-            embeddings,
-            weights,
-            token_mapping,
-            normalize,
-            median_token_length,
-            unk_token_id,
-        })
+        Self::from_raw_parts(tokenizer, &floats, rows, cols, normalize, weights, token_mapping)
     }

     /// Construct from pre-parsed parts.
@@ -171,13 +144,16 @@ impl StaticModel {
     /// * `rows` - Number of vocabulary entries
     /// * `cols` - Embedding dimension
     /// * `normalize` - Whether to L2-normalize output embeddings
-    #[allow(dead_code)]
+    /// * `weights` - Optional per-token weights for quantized models
+    /// * `token_mapping` - Optional token ID mapping for quantized models
     pub fn from_raw_parts(
         tokenizer: Tokenizer,
         embeddings: &[f32],
         rows: usize,
         cols: usize,
         normalize: bool,
+        weights: Option<Vec<f32>>,
+        token_mapping: Option<Vec<usize>>,
     ) -> Result<Self> {
         if embeddings.len() != rows * cols {
             return Err(anyhow!(
@@ -188,30 +164,12 @@ impl StaticModel {
             ));
         }

-        let median_token_length = Self::median_token_length(&tokenizer);
-        let unk_token_id = Self::unk_token_id(&tokenizer)?;
-
-        let embeddings = Array2::from_shape_vec((rows, cols), embeddings.to_vec())
-            .context("failed to build embeddings array")?;
-
-        Ok(Self {
-            tokenizer,
-            embeddings,
-            weights: None,
-            token_mapping: None,
-            normalize,
-            median_token_length,
-            unk_token_id,
-        })
-    }
-
-    fn median_token_length(tokenizer: &Tokenizer) -> usize {
+        // Median-token-length hack for pre-truncation
         let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
         lens.sort_unstable();
-        lens.get(lens.len() / 2).copied().unwrap_or(1)
-    }
+        let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);

-    fn unk_token_id(tokenizer: &Tokenizer) -> Result<Option<usize>> {
+        // Get unk_token from tokenizer (optional - BPE tokenizers may not have one)
         let spec_json = tokenizer
             .to_string(false)
             .map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
@@ -220,7 +178,20 @@ impl StaticModel {
             .get("model")
             .and_then(|m| m.get("unk_token"))
             .and_then(Value::as_str);
-        Ok(unk_token.and_then(|tok| tokenizer.token_to_id(tok)).map(|id| id as usize))
+        let unk_token_id = unk_token.and_then(|tok| tokenizer.token_to_id(tok)).map(|id| id as usize);
+
+        let embeddings = Array2::from_shape_vec((rows, cols), embeddings.to_vec())
+            .context("failed to build embeddings array")?;
+
+        Ok(Self {
+            tokenizer,
+            embeddings,
+            weights,
+            token_mapping,
+            normalize,
+            median_token_length,
+            unk_token_id,
+        })
     }

     /// Char-level truncation to max_tokens * median_token_length
diff --git a/tests/test_model.rs b/tests/test_model.rs
index 3fbb075..abb3761 100644
--- a/tests/test_model.rs
+++ b/tests/test_model.rs
@@ -89,7 +89,7 @@ fn test_from_raw_parts() {
         .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
         .collect();

-    let model = StaticModel::from_raw_parts(tokenizer, &floats, rows, cols, true).unwrap();
+    let model = StaticModel::from_raw_parts(tokenizer, &floats, rows, cols, true, None, None).unwrap();
     let emb = model.encode_single("hello");
     assert!(!emb.is_empty());
 }

From 1fa2f343647aca7899ff73b5f840825efe2a82ae Mon Sep 17 00:00:00 2001
From: Sergei Zharinov
Date: Sun, 25 Jan 2026 22:29:44 -0300
Subject: [PATCH 4/7] Reformat

---
 src/model.rs        | 8 +++++---
 tests/test_model.rs | 5 +++--
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/src/model.rs b/src/model.rs
index dff1a8a..0cc17eb 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -178,10 +178,12 @@ impl StaticModel {
             .get("model")
             .and_then(|m| m.get("unk_token"))
             .and_then(Value::as_str);
-        let unk_token_id = unk_token.and_then(|tok| tokenizer.token_to_id(tok)).map(|id| id as usize);
+        let unk_token_id = unk_token
+            .and_then(|tok| tokenizer.token_to_id(tok))
+            .map(|id| id as usize);

-        let embeddings = Array2::from_shape_vec((rows, cols), embeddings.to_vec())
-            .context("failed to build embeddings array")?;
+        let embeddings =
+            Array2::from_shape_vec((rows, cols), embeddings.to_vec()).context("failed to build embeddings array")?;

         Ok(Self {
             tokenizer,
diff --git a/tests/test_model.rs b/tests/test_model.rs
index abb3761..48db2c8 100644
--- a/tests/test_model.rs
+++ b/tests/test_model.rs
@@ -74,9 +74,9 @@ fn test_normalization_flag_override() {
 /// Test from_raw_parts constructor
 #[test]
 fn test_from_raw_parts() {
+    use safetensors::SafeTensors;
     use std::fs;
     use tokenizers::Tokenizer;
-    use safetensors::SafeTensors;

     let path = "tests/fixtures/test-model-float32";
     let tokenizer = Tokenizer::from_file(format!("{path}/tokenizer.json")).unwrap();
@@ -84,7 +84,8 @@ fn test_from_raw_parts() {
     let tensors = SafeTensors::deserialize(&bytes).unwrap();
     let tensor = tensors.tensor("embeddings").unwrap();
     let [rows, cols]: [usize; 2] = tensor.shape().try_into().unwrap();
-    let floats: Vec<f32> = tensor.data()
+    let floats: Vec<f32> = tensor
+        .data()
         .chunks_exact(4)
         .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
         .collect();

From 7c6d33407672126799d9ee25645679cd49c31ae9 Mon Sep 17 00:00:00 2001
From: Sergei Zharinov
Date: Sun, 25 Jan 2026 22:37:47 -0300
Subject: [PATCH 5/7] fix: Error if `unk_token` declared but missing from vocab

---
 src/model.rs | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/src/model.rs b/src/model.rs
index 0cc17eb..120f6d6 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -178,9 +178,14 @@ impl StaticModel {
             .get("model")
             .and_then(|m| m.get("unk_token"))
             .and_then(Value::as_str);
-        let unk_token_id = unk_token
-            .and_then(|tok| tokenizer.token_to_id(tok))
-            .map(|id| id as usize);
+        let unk_token_id = if let Some(tok) = unk_token {
+            let id = tokenizer
+                .token_to_id(tok)
+                .ok_or_else(|| anyhow!("tokenizer declares unk_token='{tok}' but it isn't in the vocab"))?;
+            Some(id as usize)
+        } else {
+            None
+        };

         let embeddings =
             Array2::from_shape_vec((rows, cols), embeddings.to_vec()).context("failed to build embeddings array")?;

From e04db6b8fee7130bb51056835abc254d688d7dc5 Mon Sep 17 00:00:00 2001
From: Sergei Zharinov
Date: Sun, 25 Jan 2026 22:38:29 -0300
Subject: [PATCH 6/7] refactor: Take ownership of embeddings in `from_raw_parts`

Fix unnecessary clone
---
 src/model.rs        | 8 ++++----
 tests/test_model.rs | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/model.rs b/src/model.rs
index 120f6d6..43857f7 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -133,14 +133,14 @@ impl StaticModel {
             Err(_) => None,
         };

-        Self::from_raw_parts(tokenizer, &floats, rows, cols, normalize, weights, token_mapping)
+        Self::from_raw_parts(tokenizer, floats, rows, cols, normalize, weights, token_mapping)
     }

     /// Construct from pre-parsed parts.
     ///
     /// # Arguments
     /// * `tokenizer` - Pre-deserialized tokenizer
-    /// * `embeddings` - Raw f32 embedding data
+    /// * `embeddings` - Raw f32 embedding data (takes ownership to avoid copy)
     /// * `rows` - Number of vocabulary entries
     /// * `cols` - Embedding dimension
     /// * `normalize` - Whether to L2-normalize output embeddings
@@ -148,7 +148,7 @@ impl StaticModel {
     /// * `token_mapping` - Optional token ID mapping for quantized models
     pub fn from_raw_parts(
         tokenizer: Tokenizer,
-        embeddings: &[f32],
+        embeddings: Vec<f32>,
         rows: usize,
         cols: usize,
         normalize: bool,
@@ -188,7 +188,7 @@ impl StaticModel {
         };

         let embeddings =
-            Array2::from_shape_vec((rows, cols), embeddings.to_vec()).context("failed to build embeddings array")?;
+            Array2::from_shape_vec((rows, cols), embeddings).context("failed to build embeddings array")?;

         Ok(Self {
             tokenizer,
diff --git a/tests/test_model.rs b/tests/test_model.rs
index 48db2c8..0bba5f6 100644
--- a/tests/test_model.rs
+++ b/tests/test_model.rs
@@ -90,7 +90,7 @@ fn test_from_raw_parts() {
         .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
         .collect();

-    let model = StaticModel::from_raw_parts(tokenizer, &floats, rows, cols, true, None, None).unwrap();
+    let model = StaticModel::from_raw_parts(tokenizer, floats, rows, cols, true, None, None).unwrap();
     let emb = model.encode_single("hello");
     assert!(!emb.is_empty());
 }

From 54c7a49ae0d834ec986e11cd50590d4cbc97f851 Mon Sep 17 00:00:00 2001
From: Sergei Zharinov
Date: Sun, 25 Jan 2026 23:34:43 -0300
Subject: [PATCH 7/7] feat: Add `from_borrowed` for zero-copy, `from_owned` for custom loading

---
 src/model.rs        | 92 +++++++++++++++++++++++++++++++++++----------
 tests/test_model.rs |  9 +++--
 2 files changed, 78 insertions(+), 23 deletions(-)

diff --git a/src/model.rs b/src/model.rs
index 43857f7..a4e22e9 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -1,9 +1,10 @@
 use anyhow::{anyhow, Context, Result};
 use half::f16;
 use hf_hub::api::sync::Api;
-use ndarray::Array2;
+use ndarray::{Array2, ArrayView2, CowArray, Ix2};
 use safetensors::{tensor::Dtype, SafeTensors};
 use serde_json::Value;
+use std::borrow::Cow;
 use std::{env, fs, path::Path};
 use tokenizers::Tokenizer;
@@ -11,9 +12,9 @@
 #[derive(Debug, Clone)]
 pub struct StaticModel {
     tokenizer: Tokenizer,
-    embeddings: Array2<f32>,
-    weights: Option<Vec<f32>>,
-    token_mapping: Option<Vec<usize>>,
+    embeddings: CowArray<'static, f32, Ix2>,
+    weights: Option<Cow<'static, [f32]>>,
+    token_mapping: Option<Cow<'static, [usize]>>,
     normalize: bool,
     median_token_length: usize,
     unk_token_id: Option<usize>,
@@ -133,20 +134,20 @@ impl StaticModel {
             Err(_) => None,
         };

-        Self::from_raw_parts(tokenizer, floats, rows, cols, normalize, weights, token_mapping)
+        Self::from_owned(tokenizer, floats, rows, cols, normalize, weights, token_mapping)
     }

-    /// Construct from pre-parsed parts.
+    /// Construct from owned data.
     ///
     /// # Arguments
     /// * `tokenizer` - Pre-deserialized tokenizer
-    /// * `embeddings` - Raw f32 embedding data (takes ownership to avoid copy)
+    /// * `embeddings` - Owned f32 embedding data
     /// * `rows` - Number of vocabulary entries
     /// * `cols` - Embedding dimension
     /// * `normalize` - Whether to L2-normalize output embeddings
     /// * `weights` - Optional per-token weights for quantized models
     /// * `token_mapping` - Optional token ID mapping for quantized models
-    pub fn from_raw_parts(
+    pub fn from_owned(
         tokenizer: Tokenizer,
         embeddings: Vec<f32>,
         rows: usize,
         cols: usize,
         normalize: bool,
@@ -164,6 +165,68 @@ impl StaticModel {
             ));
         }

+        let (median_token_length, unk_token_id) = Self::compute_metadata(&tokenizer)?;
+
+        let embeddings =
+            Array2::from_shape_vec((rows, cols), embeddings).context("failed to build embeddings array")?;
+
+        Ok(Self {
+            tokenizer,
+            embeddings: CowArray::from(embeddings),
+            weights: weights.map(Cow::Owned),
+            token_mapping: token_mapping.map(Cow::Owned),
+            normalize,
+            median_token_length,
+            unk_token_id,
+        })
+    }
+
+    /// Construct from static slices (zero-copy for embedded binary data).
+    ///
+    /// # Arguments
+    /// * `tokenizer` - Pre-deserialized tokenizer
+    /// * `embeddings` - Static f32 embedding data (borrowed, no copy)
+    /// * `rows` - Number of vocabulary entries
+    /// * `cols` - Embedding dimension
+    /// * `normalize` - Whether to L2-normalize output embeddings
+    /// * `weights` - Optional static per-token weights for quantized models
+    /// * `token_mapping` - Optional static token ID mapping for quantized models
+    #[allow(dead_code)] // Public API for external crates
+    pub fn from_borrowed(
+        tokenizer: Tokenizer,
+        embeddings: &'static [f32],
+        rows: usize,
+        cols: usize,
+        normalize: bool,
+        weights: Option<&'static [f32]>,
+        token_mapping: Option<&'static [usize]>,
+    ) -> Result<Self> {
+        if embeddings.len() != rows * cols {
+            return Err(anyhow!(
+                "embeddings length {} != rows {} * cols {}",
+                embeddings.len(),
+                rows,
+                cols
+            ));
+        }
+
+        let (median_token_length, unk_token_id) = Self::compute_metadata(&tokenizer)?;
+
+        let embeddings = ArrayView2::from_shape((rows, cols), embeddings).context("failed to build embeddings view")?;
+
+        Ok(Self {
+            tokenizer,
+            embeddings: CowArray::from(embeddings),
+            weights: weights.map(Cow::Borrowed),
+            token_mapping: token_mapping.map(Cow::Borrowed),
+            normalize,
+            median_token_length,
+            unk_token_id,
+        })
+    }
+
+    /// Compute median token length and unk_token_id from tokenizer.
+    fn compute_metadata(tokenizer: &Tokenizer) -> Result<(usize, Option<usize>)> {
         // Median-token-length hack for pre-truncation
         let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
         lens.sort_unstable();
@@ -187,18 +250,7 @@ impl StaticModel {
             None
         };

-        let embeddings =
-            Array2::from_shape_vec((rows, cols), embeddings).context("failed to build embeddings array")?;
-
-        Ok(Self {
-            tokenizer,
-            embeddings,
-            weights,
-            token_mapping,
-            normalize,
-            median_token_length,
-            unk_token_id,
-        })
+        Ok((median_token_length, unk_token_id))
     }

     /// Char-level truncation to max_tokens * median_token_length
diff --git a/tests/test_model.rs b/tests/test_model.rs
index 0bba5f6..03581dd 100644
--- a/tests/test_model.rs
+++ b/tests/test_model.rs
@@ -71,9 +71,9 @@ fn test_normalization_flag_override() {
     );
 }

-/// Test from_raw_parts constructor
+/// Test from_borrowed constructor (zero-copy path)
 #[test]
-fn test_from_raw_parts() {
+fn test_from_borrowed() {
     use safetensors::SafeTensors;
     use std::fs;
     use tokenizers::Tokenizer;
@@ -90,7 +90,10 @@ fn test_from_raw_parts() {
         .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
         .collect();

-    let model = StaticModel::from_raw_parts(tokenizer, floats, rows, cols, true, None, None).unwrap();
+    // Leak to get 'static lifetime (fine for tests)
+    let floats: &'static [f32] = Box::leak(floats.into_boxed_slice());
+
+    let model = StaticModel::from_borrowed(tokenizer, floats, rows, cols, true, None, None).unwrap();
     let emb = model.encode_single("hello");
     assert!(!emb.is_empty());
 }
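
Usage note (not part of the series): the sketch below shows how a downstream application might call the two new constructors once these patches land. The `model2vec_rs::model::StaticModel` import path and the fixture paths are assumptions made for illustration; error handling follows the unwrap style of the tests above.

use model2vec_rs::model::StaticModel; // assumed module path
use safetensors::SafeTensors;
use tokenizers::Tokenizer;

fn main() {
    // Decode the embedding matrix however the application likes; safetensors is
    // used here only because the test fixtures above already ship in that format.
    let dir = "tests/fixtures/test-model-float32"; // assumed fixture path
    let tokenizer = Tokenizer::from_file(format!("{dir}/tokenizer.json")).unwrap();
    let bytes = std::fs::read(format!("{dir}/model.safetensors")).unwrap();
    let tensors = SafeTensors::deserialize(&bytes).unwrap();
    let tensor = tensors.tensor("embeddings").unwrap();
    let [rows, cols]: [usize; 2] = tensor.shape().try_into().unwrap();
    let floats: Vec<f32> = tensor
        .data()
        .chunks_exact(4)
        .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
        .collect();

    // Owned path: the Vec is moved in and stored as Cow::Owned.
    let owned = StaticModel::from_owned(tokenizer.clone(), floats.clone(), rows, cols, true, None, None).unwrap();
    println!("owned embedding dims: {}", owned.encode_single("hello").len());

    // Borrowed path: any &'static [f32] (a compile-time table or a buffer
    // deliberately leaked at startup) is wrapped as Cow::Borrowed without a copy.
    let leaked: &'static [f32] = Box::leak(floats.into_boxed_slice());
    let zero_copy = StaticModel::from_borrowed(tokenizer, leaked, rows, cols, true, None, None).unwrap();
    println!("borrowed embedding dims: {}", zero_copy.encode_single("hello").len());
}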