From 14ea2a7b65fee12ffb0066c0c81dec45a788198f Mon Sep 17 00:00:00 2001 From: Oliver Mannion <125105+tekumara@users.noreply.github.com> Date: Mon, 1 Jun 2026 11:45:44 +1000 Subject: [PATCH 1/6] Add LateOn late-interaction model --- .../late_interaction_text_embedding.py | 3 +- fastembed/late_interaction/lateon.py | 128 ++++++++++++++++++ tests/test_late_interaction_embeddings.py | 21 +++ 3 files changed, 151 insertions(+), 1 deletion(-) create mode 100644 fastembed/late_interaction/lateon.py diff --git a/fastembed/late_interaction/late_interaction_text_embedding.py b/fastembed/late_interaction/late_interaction_text_embedding.py index 30c8b70d6..7ccecab35 100644 --- a/fastembed/late_interaction/late_interaction_text_embedding.py +++ b/fastembed/late_interaction/late_interaction_text_embedding.py @@ -6,13 +6,14 @@ from fastembed.common import OnnxProvider from fastembed.late_interaction.colbert import Colbert from fastembed.late_interaction.jina_colbert import JinaColbert +from fastembed.late_interaction.lateon import LateOn from fastembed.late_interaction.late_interaction_embedding_base import ( LateInteractionTextEmbeddingBase, ) class LateInteractionTextEmbedding(LateInteractionTextEmbeddingBase): - EMBEDDINGS_REGISTRY: list[Type[LateInteractionTextEmbeddingBase]] = [Colbert, JinaColbert] + EMBEDDINGS_REGISTRY: list[Type[LateInteractionTextEmbeddingBase]] = [Colbert, JinaColbert, LateOn] @classmethod def list_supported_models(cls) -> list[dict[str, Any]]: diff --git a/fastembed/late_interaction/lateon.py b/fastembed/late_interaction/lateon.py new file mode 100644 index 000000000..7c5816bc0 --- /dev/null +++ b/fastembed/late_interaction/lateon.py @@ -0,0 +1,128 @@ +import string +from typing import Any, Iterable, Type + +import numpy as np + +from fastembed.common.model_description import DenseModelDescription, ModelSource +from fastembed.common.onnx_model import OnnxOutputContext +from fastembed.common.types import NumpyArray +from fastembed.common.preprocessor_utils import load_tokenizer +from fastembed.common.utils import iter_batch +from fastembed.late_interaction.colbert import Colbert, ColbertEmbeddingWorker + + +supported_lateon_models: list[DenseModelDescription] = [ + DenseModelDescription( + model="lightonai/LateOn", + dim=128, + description=( + "PyLate/ColBERT late-interaction English model based on ModernBERT, " + "300 document tokens, 32 query tokens, 2025 year" + ), + license="apache-2.0", + size_in_GB=0.616, + sources=ModelSource(hf="lightonai/LateOn"), + model_file="model.onnx", + additional_files=["onnx_config.json"], + ), +] + + +class LateOn(Colbert): + QUERY_MARKER_TOKEN_ID = 50368 + DOCUMENT_MARKER_TOKEN_ID = 50369 + QUERY_LENGTH = 32 + DOCUMENT_LENGTH = 300 + MASK_TOKEN = "[MASK]" + + @classmethod + def _get_worker_class(cls) -> Type[ColbertEmbeddingWorker]: + return LateOnEmbeddingWorker + + @classmethod + def _list_supported_models(cls) -> list[DenseModelDescription]: + """Lists the supported LateOn models.""" + return supported_lateon_models + + def load_onnx_model(self) -> None: + self._load_onnx_model( + model_dir=self._model_dir, + model_file=self.model_description.model_file, + threads=self.threads, + providers=self.providers, + cuda=self.cuda, + device_id=self.device_id, + extra_session_options=self._extra_session_options, + ) + self.query_tokenizer, _ = load_tokenizer(model_dir=self._model_dir) + + assert self.tokenizer is not None + self.mask_token_id = self.special_token_to_id[self.MASK_TOKEN] + self.pad_token_id = self.mask_token_id + self.skip_list = { + self.tokenizer.encode(symbol, add_special_tokens=False).ids[0] + for symbol in string.punctuation + } + # LateOn's PyLate config uses document_length/query_length including the inserted + # [D]/[Q] prefix token. Configure the tokenizer for the pre-prefix lengths. + self.tokenizer.enable_truncation(max_length=self.DOCUMENT_LENGTH - 1) + self.query_tokenizer.enable_truncation(max_length=self.QUERY_LENGTH - 1) + + def _post_process_onnx_output( + self, output: OnnxOutputContext, is_doc: bool = True, **kwargs: Any + ) -> Iterable[NumpyArray]: + if is_doc: + yield from super()._post_process_onnx_output(output, is_doc=is_doc, **kwargs) + return + + if output.attention_mask is None: + raise ValueError("attention_mask must be provided for query post-processing") + + for embedding, attention_mask in zip(output.model_output, output.attention_mask): + # LateOn was exported with do_query_expansion=false, so query embeddings should + # only include non-padding query tokens instead of ColBERT mask-token expansion. + embedding = embedding[attention_mask == 1] + norm = np.linalg.norm(embedding, ord=2, axis=1, keepdims=True) + norm_clamped = np.maximum(norm, 1e-12) + yield embedding / norm_clamped + + def token_count( + self, + texts: str | Iterable[str], + batch_size: int = 1024, + is_doc: bool = True, + include_extension: bool = False, + **kwargs: Any, + ) -> int: + if is_doc: + return super().token_count( + texts, + batch_size=batch_size, + is_doc=is_doc, + include_extension=include_extension, + **kwargs, + ) + + if not hasattr(self, "model") or self.model is None: + self.load_onnx_model() + + token_num = 0 + texts = [texts] if isinstance(texts, str) else texts + assert self.query_tokenizer is not None + for batch in iter_batch(texts, batch_size): + for tokens in self.query_tokenizer.encode_batch(batch): + token_num += sum(tokens.attention_mask) + if include_extension: + token_num += len(batch) # add one [Q] prefix token per query + + return token_num + + +class LateOnEmbeddingWorker(ColbertEmbeddingWorker): + def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> LateOn: + return LateOn( + model_name=model_name, + cache_dir=cache_dir, + threads=1, + **kwargs, + ) diff --git a/tests/test_late_interaction_embeddings.py b/tests/test_late_interaction_embeddings.py index ea83e76a3..df4e03fb6 100644 --- a/tests/test_late_interaction_embeddings.py +++ b/tests/test_late_interaction_embeddings.py @@ -38,6 +38,15 @@ [0.0766, 0.0452, -0.2343, -0.183, 0.0058], ] ), + "lightonai/LateOn": np.array( + [ + [0.00039, 0.00651, 0.0146, 0.00346, 0.00244], + [-0.0029, 0.00423, 0.00042, 0.02236, 0.00981], + [-0.0287, 0.01159, 0.02401, -0.00312, -0.04338], + [-0.04709, 0.00209, 0.02174, -0.00381, -0.00608], + [-0.02461, -0.02876, 0.03014, -0.0035, -0.00431], + ] + ), } CANONICAL_QUERY_VALUES = { @@ -149,6 +158,15 @@ [0.0204, -0.0856, -0.0386, -0.1232, -0.0332], ] ), + "lightonai/LateOn": np.array( + [ + [0.00202, -0.02634, 0.00685, 0.00993, 0.03093], + [-0.02321, -0.0226, 0.00356, 0.02836, 0.01729], + [-0.01066, 0.00595, 0.02884, 0.00267, -0.10405], + [-0.10359, -0.06927, 0.03218, 0.05037, -0.03338], + [-0.02992, -0.03874, 0.10582, 0.06303, 0.05831], + ] + ), } _MODELS_TO_CACHE = ("answerdotai/answerai-colbert-small-v1",) @@ -296,6 +314,9 @@ def test_get_embedding_size(): model_name = "answerdotai/answerai-ColBERT-small-v1" assert LateInteractionTextEmbedding.get_embedding_size(model_name) == 96 + model_name = "lightonai/LateOn" + assert LateInteractionTextEmbedding.get_embedding_size(model_name) == 128 + def test_embedding_size(): is_ci = os.getenv("CI") From e5d287d7d345045d124519eb6c4361bcc27ee741 Mon Sep 17 00:00:00 2001 From: Oliver Mannion <125105+tekumara@users.noreply.github.com> Date: Mon, 1 Jun 2026 12:50:59 +1000 Subject: [PATCH 2/6] ruff format --- .../late_interaction/late_interaction_text_embedding.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/fastembed/late_interaction/late_interaction_text_embedding.py b/fastembed/late_interaction/late_interaction_text_embedding.py index 7ccecab35..c57b044ab 100644 --- a/fastembed/late_interaction/late_interaction_text_embedding.py +++ b/fastembed/late_interaction/late_interaction_text_embedding.py @@ -13,7 +13,11 @@ class LateInteractionTextEmbedding(LateInteractionTextEmbeddingBase): - EMBEDDINGS_REGISTRY: list[Type[LateInteractionTextEmbeddingBase]] = [Colbert, JinaColbert, LateOn] + EMBEDDINGS_REGISTRY: list[Type[LateInteractionTextEmbeddingBase]] = [ + Colbert, + JinaColbert, + LateOn, + ] @classmethod def list_supported_models(cls) -> list[dict[str, Any]]: From fb94a0da0cb4978affad468bb9e9b501e40e569c Mon Sep 17 00:00:00 2001 From: Oliver Mannion <125105+tekumara@users.noreply.github.com> Date: Mon, 1 Jun 2026 12:52:22 +1000 Subject: [PATCH 3/6] compare against pylate --- tests/test_lateon_pylate_reference.py | 136 ++++++++++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 tests/test_lateon_pylate_reference.py diff --git a/tests/test_lateon_pylate_reference.py b/tests/test_lateon_pylate_reference.py new file mode 100644 index 000000000..26b0aebf6 --- /dev/null +++ b/tests/test_lateon_pylate_reference.py @@ -0,0 +1,136 @@ +from typing import Any, Iterable + +import numpy as np +import pytest + +from fastembed import LateInteractionTextEmbedding +from fastembed.common.types import NumpyArray + + +LATEON_MODEL_NAME = "lightonai/LateOn" +REFERENCE_TEXTS = [ + "Hello World", + "Late interaction models compare token embeddings!", +] +RETRIEVAL_QUERIES = [ + "Which animal purrs?", + "What is the capital of France?", +] +RETRIEVAL_DOCUMENT_IDS = ["python", "cat", "paris"] +RETRIEVAL_DOCUMENTS = [ + "Python is a programming language used for machine learning.", + "Cats are small animals that often purr when they are happy.", + "Paris is the capital and largest city of France.", +] + + +def _as_numpy_arrays(embeddings: Iterable[NumpyArray]) -> list[np.ndarray]: + return [np.asarray(embedding, dtype=np.float32) for embedding in embeddings] + + +def _pylate_model(): + pylate_models = pytest.importorskip( + "pylate.models", reason="PyLate is required for the LateOn reference test" + ) + pytest.importorskip("torch", reason="PyLate reference inference requires PyTorch") + return pylate_models.ColBERT(model_name_or_path=LATEON_MODEL_NAME) + + +def _pylate_reference_embeddings(texts: list[str], is_query: bool) -> list[np.ndarray]: + model = _pylate_model() + embeddings = model.encode( + texts, + batch_size=2, + is_query=is_query, + show_progress_bar=False, + convert_to_numpy=True, + normalize_embeddings=True, + device="cpu", + ) + return _as_numpy_arrays(embeddings) + + +def _rerank( + queries_embeddings: list[np.ndarray], documents_embeddings: list[np.ndarray] +) -> list[list[Any]]: + pylate_rank = pytest.importorskip( + "pylate.rank", reason="PyLate is required for the LateOn retrieval reference test" + ) + return pylate_rank.rerank( + documents_ids=[RETRIEVAL_DOCUMENT_IDS] * len(queries_embeddings), + queries_embeddings=queries_embeddings, + documents_embeddings=[documents_embeddings] * len(queries_embeddings), + device="cpu", + ) + + +def _result_id(result: Any) -> str: + return result["id"] if isinstance(result, dict) else result.id + + +def _result_score(result: Any) -> float: + return result["score"] if isinstance(result, dict) else result.score + + +@pytest.mark.parametrize("is_query", [False, True]) +def test_lateon_matches_pylate_reference(is_query: bool) -> None: + pylate_embeddings = _pylate_reference_embeddings(REFERENCE_TEXTS, is_query=is_query) + + fastembed_model = LateInteractionTextEmbedding(LATEON_MODEL_NAME, threads=1) + fastembed_embeddings = _as_numpy_arrays( + fastembed_model.query_embed(REFERENCE_TEXTS) + if is_query + else fastembed_model.embed(REFERENCE_TEXTS, batch_size=2) + ) + + assert len(fastembed_embeddings) == len(pylate_embeddings) + for fastembed_embedding, pylate_embedding in zip(fastembed_embeddings, pylate_embeddings): + assert fastembed_embedding.shape == pylate_embedding.shape + assert np.allclose(fastembed_embedding, pylate_embedding, rtol=1e-3, atol=1e-4) + + +def test_lateon_retrieval_matches_pylate_reference() -> None: + pylate_model = _pylate_model() + pylate_query_embeddings = _as_numpy_arrays( + pylate_model.encode( + RETRIEVAL_QUERIES, + batch_size=2, + is_query=True, + show_progress_bar=False, + convert_to_numpy=True, + normalize_embeddings=True, + device="cpu", + ) + ) + pylate_document_embeddings = _as_numpy_arrays( + pylate_model.encode( + RETRIEVAL_DOCUMENTS, + batch_size=2, + is_query=False, + show_progress_bar=False, + convert_to_numpy=True, + normalize_embeddings=True, + device="cpu", + ) + ) + + fastembed_model = LateInteractionTextEmbedding(LATEON_MODEL_NAME, threads=1) + fastembed_query_embeddings = _as_numpy_arrays(fastembed_model.query_embed(RETRIEVAL_QUERIES)) + fastembed_document_embeddings = _as_numpy_arrays( + fastembed_model.embed(RETRIEVAL_DOCUMENTS, batch_size=2) + ) + + pylate_results = _rerank(pylate_query_embeddings, pylate_document_embeddings) + fastembed_results = _rerank(fastembed_query_embeddings, fastembed_document_embeddings) + + assert len(fastembed_results) == len(pylate_results) + for fastembed_query_results, pylate_query_results in zip(fastembed_results, pylate_results): + assert [_result_id(result) for result in fastembed_query_results] == [ + _result_id(result) for result in pylate_query_results + ] + assert np.allclose( + [_result_score(result) for result in fastembed_query_results], + [_result_score(result) for result in pylate_query_results], + rtol=1e-3, + atol=1e-3, + ) From 9e402b3597cfb61176940595bad04f98dd2c4d3a Mon Sep 17 00:00:00 2001 From: Oliver Mannion <125105+tekumara@users.noreply.github.com> Date: Mon, 1 Jun 2026 12:52:43 +1000 Subject: [PATCH 4/6] remove pylate comparision --- tests/test_lateon_pylate_reference.py | 136 -------------------------- 1 file changed, 136 deletions(-) delete mode 100644 tests/test_lateon_pylate_reference.py diff --git a/tests/test_lateon_pylate_reference.py b/tests/test_lateon_pylate_reference.py deleted file mode 100644 index 26b0aebf6..000000000 --- a/tests/test_lateon_pylate_reference.py +++ /dev/null @@ -1,136 +0,0 @@ -from typing import Any, Iterable - -import numpy as np -import pytest - -from fastembed import LateInteractionTextEmbedding -from fastembed.common.types import NumpyArray - - -LATEON_MODEL_NAME = "lightonai/LateOn" -REFERENCE_TEXTS = [ - "Hello World", - "Late interaction models compare token embeddings!", -] -RETRIEVAL_QUERIES = [ - "Which animal purrs?", - "What is the capital of France?", -] -RETRIEVAL_DOCUMENT_IDS = ["python", "cat", "paris"] -RETRIEVAL_DOCUMENTS = [ - "Python is a programming language used for machine learning.", - "Cats are small animals that often purr when they are happy.", - "Paris is the capital and largest city of France.", -] - - -def _as_numpy_arrays(embeddings: Iterable[NumpyArray]) -> list[np.ndarray]: - return [np.asarray(embedding, dtype=np.float32) for embedding in embeddings] - - -def _pylate_model(): - pylate_models = pytest.importorskip( - "pylate.models", reason="PyLate is required for the LateOn reference test" - ) - pytest.importorskip("torch", reason="PyLate reference inference requires PyTorch") - return pylate_models.ColBERT(model_name_or_path=LATEON_MODEL_NAME) - - -def _pylate_reference_embeddings(texts: list[str], is_query: bool) -> list[np.ndarray]: - model = _pylate_model() - embeddings = model.encode( - texts, - batch_size=2, - is_query=is_query, - show_progress_bar=False, - convert_to_numpy=True, - normalize_embeddings=True, - device="cpu", - ) - return _as_numpy_arrays(embeddings) - - -def _rerank( - queries_embeddings: list[np.ndarray], documents_embeddings: list[np.ndarray] -) -> list[list[Any]]: - pylate_rank = pytest.importorskip( - "pylate.rank", reason="PyLate is required for the LateOn retrieval reference test" - ) - return pylate_rank.rerank( - documents_ids=[RETRIEVAL_DOCUMENT_IDS] * len(queries_embeddings), - queries_embeddings=queries_embeddings, - documents_embeddings=[documents_embeddings] * len(queries_embeddings), - device="cpu", - ) - - -def _result_id(result: Any) -> str: - return result["id"] if isinstance(result, dict) else result.id - - -def _result_score(result: Any) -> float: - return result["score"] if isinstance(result, dict) else result.score - - -@pytest.mark.parametrize("is_query", [False, True]) -def test_lateon_matches_pylate_reference(is_query: bool) -> None: - pylate_embeddings = _pylate_reference_embeddings(REFERENCE_TEXTS, is_query=is_query) - - fastembed_model = LateInteractionTextEmbedding(LATEON_MODEL_NAME, threads=1) - fastembed_embeddings = _as_numpy_arrays( - fastembed_model.query_embed(REFERENCE_TEXTS) - if is_query - else fastembed_model.embed(REFERENCE_TEXTS, batch_size=2) - ) - - assert len(fastembed_embeddings) == len(pylate_embeddings) - for fastembed_embedding, pylate_embedding in zip(fastembed_embeddings, pylate_embeddings): - assert fastembed_embedding.shape == pylate_embedding.shape - assert np.allclose(fastembed_embedding, pylate_embedding, rtol=1e-3, atol=1e-4) - - -def test_lateon_retrieval_matches_pylate_reference() -> None: - pylate_model = _pylate_model() - pylate_query_embeddings = _as_numpy_arrays( - pylate_model.encode( - RETRIEVAL_QUERIES, - batch_size=2, - is_query=True, - show_progress_bar=False, - convert_to_numpy=True, - normalize_embeddings=True, - device="cpu", - ) - ) - pylate_document_embeddings = _as_numpy_arrays( - pylate_model.encode( - RETRIEVAL_DOCUMENTS, - batch_size=2, - is_query=False, - show_progress_bar=False, - convert_to_numpy=True, - normalize_embeddings=True, - device="cpu", - ) - ) - - fastembed_model = LateInteractionTextEmbedding(LATEON_MODEL_NAME, threads=1) - fastembed_query_embeddings = _as_numpy_arrays(fastembed_model.query_embed(RETRIEVAL_QUERIES)) - fastembed_document_embeddings = _as_numpy_arrays( - fastembed_model.embed(RETRIEVAL_DOCUMENTS, batch_size=2) - ) - - pylate_results = _rerank(pylate_query_embeddings, pylate_document_embeddings) - fastembed_results = _rerank(fastembed_query_embeddings, fastembed_document_embeddings) - - assert len(fastembed_results) == len(pylate_results) - for fastembed_query_results, pylate_query_results in zip(fastembed_results, pylate_results): - assert [_result_id(result) for result in fastembed_query_results] == [ - _result_id(result) for result in pylate_query_results - ] - assert np.allclose( - [_result_score(result) for result in fastembed_query_results], - [_result_score(result) for result in pylate_query_results], - rtol=1e-3, - atol=1e-3, - ) From 54867d2af7cf5bf3ee1977bcecfd8c8cf9acbd76 Mon Sep 17 00:00:00 2001 From: Oliver Mannion <125105+tekumara@users.noreply.github.com> Date: Mon, 1 Jun 2026 13:04:24 +1000 Subject: [PATCH 5/6] pylate canonical values script --- pylate_canonical_values.txt | 27 +++++ scripts/generate_lateon_canonical_values.py | 124 ++++++++++++++++++++ 2 files changed, 151 insertions(+) create mode 100644 pylate_canonical_values.txt create mode 100755 scripts/generate_lateon_canonical_values.py diff --git a/pylate_canonical_values.txt b/pylate_canonical_values.txt new file mode 100644 index 000000000..e50ddc4dd --- /dev/null +++ b/pylate_canonical_values.txt @@ -0,0 +1,27 @@ +# Generated from the PyLate reference implementation, not FastEmbed. +# +# Reference code: +# from pylate import models +# model = models.ColBERT(model_name_or_path='lightonai/LateOn') +# model.encode([text], is_query=False/True, convert_to_numpy=True, normalize_embeddings=True) +# text = 'Hello World' +# document_shape = (5, 128) +# query_shape = (5, 128) + +# CANONICAL_COLUMN_VALUES entry + "lightonai/LateOn": np.array( + [[ 0.00039, 0.00651, 0.0146 , 0.00346, 0.00244], + [-0.0029 , 0.00423, 0.00042, 0.02236, 0.00981], + [-0.0287 , 0.01159, 0.02401, -0.00312, -0.04338], + [-0.04709, 0.00209, 0.02174, -0.00381, -0.00608], + [-0.02461, -0.02876, 0.03014, -0.0035 , -0.00431]] + ), + +# CANONICAL_QUERY_VALUES entry + "lightonai/LateOn": np.array( + [[ 0.00202, -0.02634, 0.00685, 0.00993, 0.03093], + [-0.02321, -0.0226 , 0.00356, 0.02836, 0.01729], + [-0.01066, 0.00595, 0.02884, 0.00267, -0.10405], + [-0.10359, -0.06927, 0.03217, 0.05037, -0.03338], + [-0.02992, -0.03874, 0.10582, 0.06303, 0.05831]] + ), diff --git a/scripts/generate_lateon_canonical_values.py b/scripts/generate_lateon_canonical_values.py new file mode 100755 index 000000000..2fb137568 --- /dev/null +++ b/scripts/generate_lateon_canonical_values.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python3 +"""Generate LateOn canonical test values from the PyLate reference implementation. + +This script prints the abridged document and query vectors used by +``tests/test_late_interaction_embeddings.py``. It intentionally uses PyLate, +not FastEmbed, so the generated values come from the original reference model. + +Example: + python scripts/generate_lateon_canonical_values.py +""" + +from __future__ import annotations + +import argparse +from typing import Sequence + +import numpy as np + + +DEFAULT_MODEL = "lightonai/LateOn" +DEFAULT_TEXT = "Hello World" + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Generate canonical LateOn test vectors with PyLate." + ) + parser.add_argument("--model", default=DEFAULT_MODEL, help="PyLate/HF model name") + parser.add_argument("--text", default=DEFAULT_TEXT, help="Text to encode") + parser.add_argument( + "--rows", + type=int, + default=5, + help="Number of token rows to print from each embedding", + ) + parser.add_argument( + "--dims", + type=int, + default=5, + help="Number of dimensions to print from each token embedding", + ) + parser.add_argument( + "--precision", + type=int, + default=5, + help="Decimal precision for printed values", + ) + parser.add_argument( + "--device", + default="cpu", + help="Device passed to PyLate encode, e.g. cpu or cuda", + ) + return parser.parse_args() + + +def load_pylate_model(model_name: str): + try: + from pylate import models + except ImportError as exc: + raise SystemExit( + "PyLate is required to generate reference values. " + "Install it with `pip install pylate`." + ) from exc + + return models.ColBERT(model_name_or_path=model_name) + + +def encode_reference(model, texts: Sequence[str], *, is_query: bool, device: str) -> np.ndarray: + embeddings = model.encode( + list(texts), + batch_size=1, + is_query=is_query, + show_progress_bar=False, + convert_to_numpy=True, + normalize_embeddings=True, + device=device, + ) + return np.asarray(embeddings[0], dtype=np.float32) + + +def format_dict_entry(model_name: str, values: np.ndarray, *, precision: int) -> str: + array = np.array2string( + values, + precision=precision, + separator=", ", + suppress_small=False, + ) + # Indent nested array lines to match the style in tests/test_late_interaction_embeddings.py. + array = "\n".join(f" {line}" for line in array.splitlines()) + return f' "{model_name}": np.array(\n{array}\n ),' + + +def main() -> None: + args = parse_args() + model = load_pylate_model(args.model) + + document_embedding = encode_reference(model, [args.text], is_query=False, device=args.device) + query_embedding = encode_reference(model, [args.text], is_query=True, device=args.device) + + document_values = document_embedding[: args.rows, : args.dims] + query_values = query_embedding[: args.rows, : args.dims] + + print("# Generated from the PyLate reference implementation, not FastEmbed.") + print("#") + print("# Reference code:") + print("# from pylate import models") + print(f"# model = models.ColBERT(model_name_or_path={args.model!r})") + print( + "# model.encode([text], is_query=False/True, " + "convert_to_numpy=True, normalize_embeddings=True)" + ) + print(f"# text = {args.text!r}") + print(f"# document_shape = {tuple(document_embedding.shape)}") + print(f"# query_shape = {tuple(query_embedding.shape)}") + print() + print("# CANONICAL_COLUMN_VALUES entry") + print(format_dict_entry(args.model, document_values, precision=args.precision)) + print() + print("# CANONICAL_QUERY_VALUES entry") + print(format_dict_entry(args.model, query_values, precision=args.precision)) + + +if __name__ == "__main__": + main() From 02f9145c645767945188c90e93230abb156d61d6 Mon Sep 17 00:00:00 2001 From: Oliver Mannion <125105+tekumara@users.noreply.github.com> Date: Mon, 1 Jun 2026 13:04:42 +1000 Subject: [PATCH 6/6] remove pylate canonical values script --- pylate_canonical_values.txt | 27 ----- scripts/generate_lateon_canonical_values.py | 124 -------------------- 2 files changed, 151 deletions(-) delete mode 100644 pylate_canonical_values.txt delete mode 100755 scripts/generate_lateon_canonical_values.py diff --git a/pylate_canonical_values.txt b/pylate_canonical_values.txt deleted file mode 100644 index e50ddc4dd..000000000 --- a/pylate_canonical_values.txt +++ /dev/null @@ -1,27 +0,0 @@ -# Generated from the PyLate reference implementation, not FastEmbed. -# -# Reference code: -# from pylate import models -# model = models.ColBERT(model_name_or_path='lightonai/LateOn') -# model.encode([text], is_query=False/True, convert_to_numpy=True, normalize_embeddings=True) -# text = 'Hello World' -# document_shape = (5, 128) -# query_shape = (5, 128) - -# CANONICAL_COLUMN_VALUES entry - "lightonai/LateOn": np.array( - [[ 0.00039, 0.00651, 0.0146 , 0.00346, 0.00244], - [-0.0029 , 0.00423, 0.00042, 0.02236, 0.00981], - [-0.0287 , 0.01159, 0.02401, -0.00312, -0.04338], - [-0.04709, 0.00209, 0.02174, -0.00381, -0.00608], - [-0.02461, -0.02876, 0.03014, -0.0035 , -0.00431]] - ), - -# CANONICAL_QUERY_VALUES entry - "lightonai/LateOn": np.array( - [[ 0.00202, -0.02634, 0.00685, 0.00993, 0.03093], - [-0.02321, -0.0226 , 0.00356, 0.02836, 0.01729], - [-0.01066, 0.00595, 0.02884, 0.00267, -0.10405], - [-0.10359, -0.06927, 0.03217, 0.05037, -0.03338], - [-0.02992, -0.03874, 0.10582, 0.06303, 0.05831]] - ), diff --git a/scripts/generate_lateon_canonical_values.py b/scripts/generate_lateon_canonical_values.py deleted file mode 100755 index 2fb137568..000000000 --- a/scripts/generate_lateon_canonical_values.py +++ /dev/null @@ -1,124 +0,0 @@ -#!/usr/bin/env python3 -"""Generate LateOn canonical test values from the PyLate reference implementation. - -This script prints the abridged document and query vectors used by -``tests/test_late_interaction_embeddings.py``. It intentionally uses PyLate, -not FastEmbed, so the generated values come from the original reference model. - -Example: - python scripts/generate_lateon_canonical_values.py -""" - -from __future__ import annotations - -import argparse -from typing import Sequence - -import numpy as np - - -DEFAULT_MODEL = "lightonai/LateOn" -DEFAULT_TEXT = "Hello World" - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Generate canonical LateOn test vectors with PyLate." - ) - parser.add_argument("--model", default=DEFAULT_MODEL, help="PyLate/HF model name") - parser.add_argument("--text", default=DEFAULT_TEXT, help="Text to encode") - parser.add_argument( - "--rows", - type=int, - default=5, - help="Number of token rows to print from each embedding", - ) - parser.add_argument( - "--dims", - type=int, - default=5, - help="Number of dimensions to print from each token embedding", - ) - parser.add_argument( - "--precision", - type=int, - default=5, - help="Decimal precision for printed values", - ) - parser.add_argument( - "--device", - default="cpu", - help="Device passed to PyLate encode, e.g. cpu or cuda", - ) - return parser.parse_args() - - -def load_pylate_model(model_name: str): - try: - from pylate import models - except ImportError as exc: - raise SystemExit( - "PyLate is required to generate reference values. " - "Install it with `pip install pylate`." - ) from exc - - return models.ColBERT(model_name_or_path=model_name) - - -def encode_reference(model, texts: Sequence[str], *, is_query: bool, device: str) -> np.ndarray: - embeddings = model.encode( - list(texts), - batch_size=1, - is_query=is_query, - show_progress_bar=False, - convert_to_numpy=True, - normalize_embeddings=True, - device=device, - ) - return np.asarray(embeddings[0], dtype=np.float32) - - -def format_dict_entry(model_name: str, values: np.ndarray, *, precision: int) -> str: - array = np.array2string( - values, - precision=precision, - separator=", ", - suppress_small=False, - ) - # Indent nested array lines to match the style in tests/test_late_interaction_embeddings.py. - array = "\n".join(f" {line}" for line in array.splitlines()) - return f' "{model_name}": np.array(\n{array}\n ),' - - -def main() -> None: - args = parse_args() - model = load_pylate_model(args.model) - - document_embedding = encode_reference(model, [args.text], is_query=False, device=args.device) - query_embedding = encode_reference(model, [args.text], is_query=True, device=args.device) - - document_values = document_embedding[: args.rows, : args.dims] - query_values = query_embedding[: args.rows, : args.dims] - - print("# Generated from the PyLate reference implementation, not FastEmbed.") - print("#") - print("# Reference code:") - print("# from pylate import models") - print(f"# model = models.ColBERT(model_name_or_path={args.model!r})") - print( - "# model.encode([text], is_query=False/True, " - "convert_to_numpy=True, normalize_embeddings=True)" - ) - print(f"# text = {args.text!r}") - print(f"# document_shape = {tuple(document_embedding.shape)}") - print(f"# query_shape = {tuple(query_embedding.shape)}") - print() - print("# CANONICAL_COLUMN_VALUES entry") - print(format_dict_entry(args.model, document_values, precision=args.precision)) - print() - print("# CANONICAL_QUERY_VALUES entry") - print(format_dict_entry(args.model, query_values, precision=args.precision)) - - -if __name__ == "__main__": - main()