146 changes: 117 additions & 29 deletions src/model.rs
@@ -1,19 +1,20 @@
use anyhow::{anyhow, Context, Result};
use half::f16;
use hf_hub::api::sync::Api;
use ndarray::Array2;
use ndarray::{Array2, ArrayView2, CowArray, Ix2};
use safetensors::{tensor::Dtype, SafeTensors};
use serde_json::Value;
use std::borrow::Cow;
use std::{env, fs, path::Path};
use tokenizers::Tokenizer;

/// Static embedding model for Model2Vec
#[derive(Debug, Clone)]
pub struct StaticModel {
tokenizer: Tokenizer,
embeddings: Array2<f32>,
weights: Option<Vec<f32>>,
token_mapping: Option<Vec<usize>>,
embeddings: CowArray<'static, f32, Ix2>,
weights: Option<Cow<'static, [f32]>>,
token_mapping: Option<Cow<'static, [usize]>>,
normalize: bool,
median_token_length: usize,
unk_token_id: Option<usize>,
@@ -64,32 +65,12 @@ impl StaticModel {
// Load the tokenizer
let tokenizer = Tokenizer::from_file(&tok_path).map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;

// Median-token-length hack for pre-truncation
let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
lens.sort_unstable();
let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);

// Read normalize default from config.json
let cfg_file = std::fs::File::open(&cfg_path).context("failed to read config.json")?;
let cfg: Value = serde_json::from_reader(&cfg_file).context("failed to parse config.json")?;
let cfg_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
let normalize = normalize.unwrap_or(cfg_norm);

// Serialize the tokenizer to JSON, then parse it and get the unk_token
let spec_json = tokenizer
.to_string(false)
.map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
let spec: Value = serde_json::from_str(&spec_json)?;
let unk_token = spec
.get("model")
.and_then(|m| m.get("unk_token"))
.and_then(Value::as_str)
.unwrap_or("[UNK]");
let unk_token_id = tokenizer
.token_to_id(unk_token)
.ok_or_else(|| anyhow!("tokenizer claims unk_token='{unk_token}' but it isn't in the vocab"))?
as usize;

// Load the safetensors
let model_bytes = fs::read(&mdl_path).context("failed to read model.safetensors")?;
let safet = SafeTensors::deserialize(&model_bytes).context("failed to parse safetensors")?;
@@ -115,7 +96,6 @@ impl StaticModel {
Dtype::I8 => raw.iter().map(|&b| f32::from(b as i8)).collect(),
other => return Err(anyhow!("unsupported tensor dtype: {other:?}")),
};
let embeddings = Array2::from_shape_vec((rows, cols), floats).context("failed to build embeddings array")?;

// Load optional weights for vocabulary quantization
let weights = match safet.tensor("weights") {
@@ -154,17 +134,125 @@ impl StaticModel {
Err(_) => None,
};

Self::from_owned(tokenizer, floats, rows, cols, normalize, weights, token_mapping)
}

/// Construct from owned data.
///
/// # Arguments
/// * `tokenizer` - Pre-deserialized tokenizer
/// * `embeddings` - Owned f32 embedding data
/// * `rows` - Number of vocabulary entries
/// * `cols` - Embedding dimension
/// * `normalize` - Whether to L2-normalize output embeddings
/// * `weights` - Optional per-token weights for quantized models
/// * `token_mapping` - Optional token ID mapping for quantized models
pub fn from_owned(
tokenizer: Tokenizer,
embeddings: Vec<f32>,
rows: usize,
cols: usize,
normalize: bool,
weights: Option<Vec<f32>>,
token_mapping: Option<Vec<usize>>,
) -> Result<Self> {
if embeddings.len() != rows * cols {
return Err(anyhow!(
"embeddings length {} != rows {} * cols {}",
embeddings.len(),
rows,
cols
));
}

let (median_token_length, unk_token_id) = Self::compute_metadata(&tokenizer)?;

let embeddings =
Array2::from_shape_vec((rows, cols), embeddings).context("failed to build embeddings array")?;

Ok(Self {
tokenizer,
embeddings: CowArray::from(embeddings),
weights: weights.map(Cow::Owned),
token_mapping: token_mapping.map(Cow::Owned),
normalize,
median_token_length,
unk_token_id,
})
}

/// Construct from static slices (zero-copy for embedded binary data).
///
/// # Arguments
/// * `tokenizer` - Pre-deserialized tokenizer
/// * `embeddings` - Static f32 embedding data (borrowed, no copy)
/// * `rows` - Number of vocabulary entries
/// * `cols` - Embedding dimension
/// * `normalize` - Whether to L2-normalize output embeddings
/// * `weights` - Optional static per-token weights for quantized models
/// * `token_mapping` - Optional static token ID mapping for quantized models
#[allow(dead_code)] // Public API for external crates
pub fn from_borrowed(
tokenizer: Tokenizer,
embeddings: &'static [f32],
rows: usize,
cols: usize,
normalize: bool,
weights: Option<&'static [f32]>,
token_mapping: Option<&'static [usize]>,
) -> Result<Self> {
if embeddings.len() != rows * cols {
return Err(anyhow!(
"embeddings length {} != rows {} * cols {}",
embeddings.len(),
rows,
cols
));
}

let (median_token_length, unk_token_id) = Self::compute_metadata(&tokenizer)?;

let embeddings = ArrayView2::from_shape((rows, cols), embeddings).context("failed to build embeddings view")?;

Ok(Self {
tokenizer,
embeddings,
weights,
token_mapping,
embeddings: CowArray::from(embeddings),
weights: weights.map(Cow::Borrowed),
token_mapping: token_mapping.map(Cow::Borrowed),
normalize,
median_token_length,
unk_token_id: Some(unk_token_id),
unk_token_id,
})
}

/// Compute median token length and unk_token_id from tokenizer.
fn compute_metadata(tokenizer: &Tokenizer) -> Result<(usize, Option<usize>)> {
// Median-token-length hack for pre-truncation
let mut lens: Vec<usize> = tokenizer.get_vocab(false).keys().map(|tk| tk.len()).collect();
lens.sort_unstable();
let median_token_length = lens.get(lens.len() / 2).copied().unwrap_or(1);

// Get unk_token from tokenizer (optional - BPE tokenizers may not have one)
let spec_json = tokenizer
.to_string(false)
.map_err(|e| anyhow!("tokenizer -> JSON failed: {e}"))?;
let spec: Value = serde_json::from_str(&spec_json)?;
let unk_token = spec
.get("model")
.and_then(|m| m.get("unk_token"))
.and_then(Value::as_str);
let unk_token_id = if let Some(tok) = unk_token {
let id = tokenizer
.token_to_id(tok)
.ok_or_else(|| anyhow!("tokenizer declares unk_token='{tok}' but it isn't in the vocab"))?;
Some(id as usize)
} else {
None
};

Ok((median_token_length, unk_token_id))
}

/// Char-level truncation to max_tokens * median_token_length
fn truncate_str(s: &str, max_tokens: usize, median_len: usize) -> &str {
let max_chars = max_tokens.saturating_mul(median_len);
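
The new from_owned constructor lets a caller decode the embedding matrix itself and hand ownership to the model. The sketch below shows one way that call might look, assuming StaticModel is in scope; the directory layout, the "embeddings" tensor name, the plain little-endian f32 decoding, and the load_owned helper name mirror the loader and test in this PR but are illustrative assumptions, not part of the diff.

use anyhow::{anyhow, Result};
use safetensors::SafeTensors;
use tokenizers::Tokenizer;

// Hypothetical helper: build a StaticModel from tensors decoded by the caller.
fn load_owned(dir: &str) -> Result<StaticModel> {
    // Load the tokenizer the same way the built-in loader does.
    let tokenizer = Tokenizer::from_file(format!("{dir}/tokenizer.json"))
        .map_err(|e| anyhow!("failed to load tokenizer: {e}"))?;

    // Decode an f32 embedding matrix from safetensors by hand.
    let bytes = std::fs::read(format!("{dir}/model.safetensors"))?;
    let tensors = SafeTensors::deserialize(&bytes)?;
    let tensor = tensors.tensor("embeddings")?;
    let (rows, cols) = (tensor.shape()[0], tensor.shape()[1]);
    let floats: Vec<f32> = tensor
        .data()
        .chunks_exact(4)
        .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
        .collect();

    // Ownership of `floats` moves into the model; normalization on,
    // no per-token weights or token mapping.
    StaticModel::from_owned(tokenizer, floats, rows, cols, true, None, None)
}
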
27 changes: 27 additions & 0 deletions tests/test_model.rs
@@ -70,3 +70,30 @@ fn test_normalization_flag_override() {
"Without normalization override, norm should be larger"
);
}

/// Test from_borrowed constructor (zero-copy path)
#[test]
fn test_from_borrowed() {
use safetensors::SafeTensors;
use std::fs;
use tokenizers::Tokenizer;

let path = "tests/fixtures/test-model-float32";
let tokenizer = Tokenizer::from_file(format!("{path}/tokenizer.json")).unwrap();
let bytes = fs::read(format!("{path}/model.safetensors")).unwrap();
let tensors = SafeTensors::deserialize(&bytes).unwrap();
let tensor = tensors.tensor("embeddings").unwrap();
let [rows, cols]: [usize; 2] = tensor.shape().try_into().unwrap();
let floats: Vec<f32> = tensor
.data()
.chunks_exact(4)
.map(|b| f32::from_le_bytes(b.try_into().unwrap()))
.collect();

// Leak to get 'static lifetime (fine for tests)
let floats: &'static [f32] = Box::leak(floats.into_boxed_slice());

let model = StaticModel::from_borrowed(tokenizer, floats, rows, cols, true, None, None).unwrap();
let emb = model.encode_single("hello");
assert!(!emb.is_empty());
}
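
The test obtains a 'static slice by leaking a heap buffer, which is fine in a test. When the embedding data is genuinely compiled into the binary, the borrowed path needs no leak at all. A minimal sketch under that assumption follows; the 2 x 3 matrix and the build_zero_copy name are made up for illustration, and with a real model the row count must match the tokenizer's vocabulary size.

use anyhow::Result;
use tokenizers::Tokenizer;

// Hypothetical 2 x 3 embedding matrix baked into the binary.
static EMBEDDINGS: [f32; 6] = [0.10, 0.20, 0.30, 0.40, 0.50, 0.60];

fn build_zero_copy(tokenizer: Tokenizer) -> Result<StaticModel> {
    // The model borrows the static slice directly; nothing is copied.
    StaticModel::from_borrowed(tokenizer, &EMBEDDINGS, 2, 3, true, None, None)
}

Raw include_bytes! data would additionally need an alignment-checked reinterpretation to &[f32] (bytemuck::try_cast_slice, for instance) before it could be passed to from_borrowed.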