From 41f02e51017f312db663d2514718783019285363 Mon Sep 17 00:00:00 2001
From: Alexandre Galashov
Date: Thu, 19 Feb 2026 06:20:15 -0800
Subject: [PATCH] Add `AdditiveSequenceEmbedding`

PiperOrigin-RevId: 872366209
---
 hackable_diffusion/lib/architecture/mlp.py   |  5 +--
 .../lib/architecture/mlp_blocks.py           |  2 +-
 .../lib/architecture/sequence_embedders.py   | 23 +++++++++++++
 .../architecture/sequence_embedders_test.py  | 34 ++++++++++++++++++-
 hackable_diffusion/lib/architecture/unet.py  |  1 +
 .../lib/architecture/unet_blocks.py          |  1 +
 6 files changed, 62 insertions(+), 4 deletions(-)

diff --git a/hackable_diffusion/lib/architecture/mlp.py b/hackable_diffusion/lib/architecture/mlp.py
index 89bfa7c..cc4585a 100644
--- a/hackable_diffusion/lib/architecture/mlp.py
+++ b/hackable_diffusion/lib/architecture/mlp.py
@@ -78,6 +78,7 @@ def __call__(
       self,
       x: DataArray,
       conditioning_embeddings: dict[ConditioningMechanism, Float['batch ...']],
+      *,
       is_training: bool,
   ) -> DataArray:
     x_emb = jnp.reshape(x, shape=(x.shape[0], -1))
@@ -92,7 +93,7 @@
         dropout_rate=self.dropout_rate,
         dtype=self.dtype,
         name='PreprocessMLP',
-    )(x_emb, is_training)
+    )(x_emb, is_training=is_training)
 
     # The conditioning was already processed by the `conditioning_encoder`, so
     # here we just need to concatenate it with the `x`.
@@ -125,7 +126,7 @@
         dtype=self.dtype,
         zero_init_output=self.zero_init_output,
         name='PostprocessMLP',
-    )(emb, is_training)
+    )(emb, is_training=is_training)
 
     output = jnp.reshape(output, shape=x.shape)
     output = utils.optional_bf16_to_fp32(output)
diff --git a/hackable_diffusion/lib/architecture/mlp_blocks.py b/hackable_diffusion/lib/architecture/mlp_blocks.py
index 7ef162c..29189ae 100644
--- a/hackable_diffusion/lib/architecture/mlp_blocks.py
+++ b/hackable_diffusion/lib/architecture/mlp_blocks.py
@@ -46,7 +46,7 @@ class MLP(nn.Module):
   @nn.compact
   @typechecked
   def __call__(
-      self, x: Float['batch num_inputs'], is_training: bool
+      self, x: Float['batch num_inputs'], *, is_training: bool
   ) -> Float['batch num_features']:
     """Applies MLP blocks to the input tensor.
diff --git a/hackable_diffusion/lib/architecture/sequence_embedders.py b/hackable_diffusion/lib/architecture/sequence_embedders.py
index 8bf1c36..fc5cab6 100644
--- a/hackable_diffusion/lib/architecture/sequence_embedders.py
+++ b/hackable_diffusion/lib/architecture/sequence_embedders.py
@@ -186,3 +186,26 @@ def __call__(
     out = jnp.asarray(jnp.concatenate(result, axis=-1), x.dtype)
 
     return out
+
+
+class AdditiveSequenceEmbedding(nn.Module):
+  """Learnable additive sequence positional embedding."""
+
+  num_features: int
+
+  def setup(self):
+    if self.num_features <= 0:
+      raise ValueError("Number of features must be positive.")
+
+  @nn.compact
+  @typechecked
+  def __call__(
+      self, x: Num["batch *#data_shape"]
+  ) -> Float["batch *#data_shape"]:
+    pos_embed = self.param(
+        "PositionalEmbeddingTensor",
+        nn.initializers.normal(stddev=0.02),
+        (1, *x.shape[1:]),
+        x.dtype,
+    )
+    return x + pos_embed
diff --git a/hackable_diffusion/lib/architecture/sequence_embedders_test.py b/hackable_diffusion/lib/architecture/sequence_embedders_test.py
index 9f06858..1c0dc31 100644
--- a/hackable_diffusion/lib/architecture/sequence_embedders_test.py
+++ b/hackable_diffusion/lib/architecture/sequence_embedders_test.py
@@ -38,7 +38,11 @@
 def _get_invalid_num_features_params():
   """Generates parameters for testing invalid num_features."""
   params = []
-  modes = ["sinusoidal_embedding", "random_fourier_embedding"]
+  modes = [
+      "sinusoidal_embedding",
+      "random_fourier_embedding",
+      "additive_embedding",
+  ]
   feature_values = [
       ("default", INVALID_INT),
       ("zero", 0),
@@ -102,6 +106,10 @@
       module = sequence_embedders.RandomFourierSequenceEmbedding(
           num_features=num_features
       )
+    elif mode == "additive_embedding":
+      module = sequence_embedders.AdditiveSequenceEmbedding(
+          num_features=num_features
+      )
     else:
       self.fail(f"Unknown mode: {mode}")
     inputs = jnp.arange(self.batch_size)
@@ -244,6 +252,30 @@ def test_rope_embedding_has_no_params(self):
     variables = module.init(self.rng, x_rope)
     self.assertEmpty(variables)
+
+  # MARK: AdditiveSequenceEmbedding tests
+
+  def test_additive_embedding_output_shape(self):
+    """Tests the output shape of AdditiveSequenceEmbedding."""
+    module = sequence_embedders.AdditiveSequenceEmbedding(num_features=self.dim)
+    variables = module.init({"params": self.rng}, self.x)
+    output = module.apply(variables, self.x)
+    self.assertEqual(output.shape, self.x.shape)
+
+  def test_additive_embedding_params_are_updated(self):
+    """Tests that AdditiveSequenceEmbedding params are updated."""
+    module = sequence_embedders.AdditiveSequenceEmbedding(num_features=self.dim)
+    variables = module.init({"params": self.rng}, self.x)
+    initial_params = variables["params"]
+
+    def loss_fn(params):
+      output = module.apply({"params": params}, self.x)
+      return jnp.sum(output)
+
+    grads = jax.grad(loss_fn)(initial_params)
+
+    # Check that the gradients are not zero.
+    self.assertFalse(jnp.allclose(grads["PositionalEmbeddingTensor"], 0.0))
 
 
 if __name__ == "__main__":
   absltest.main()
diff --git a/hackable_diffusion/lib/architecture/unet.py b/hackable_diffusion/lib/architecture/unet.py
index fd6d499..105009b 100644
--- a/hackable_diffusion/lib/architecture/unet.py
+++ b/hackable_diffusion/lib/architecture/unet.py
@@ -164,6 +164,7 @@ def __call__(
       self,
       x: Float["batch height width channels"],
       conditioning_embeddings: dict[ConditioningMechanism, Float["batch ..."]],
+      *,
       is_training: bool,
   ) -> Float["batch height width output_channels"]:
 
diff --git a/hackable_diffusion/lib/architecture/unet_blocks.py b/hackable_diffusion/lib/architecture/unet_blocks.py
index 3199abb..134b465 100644
--- a/hackable_diffusion/lib/architecture/unet_blocks.py
+++ b/hackable_diffusion/lib/architecture/unet_blocks.py
@@ -266,6 +266,7 @@ def __call__(
       self,
       x: Float["batch height width channels"],
       cross_attention_emb: Float["batch seq cond_dim2"] | None,
+      *,
       is_training: bool,
   ) -> Float["batch height width channels"]:
     skip = x
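
Reviewer note, not part of the patch: a minimal usage sketch of the new
module, mirroring what the added tests do. It assumes the repository's
Flax/JAX setup (flax.linen, jax.numpy, jaxtyping); the input shape, RNG seed,
and `num_features` value below are illustrative, while the module name, its
`num_features` argument, and the `PositionalEmbeddingTensor` parameter come
from the diff above.

    import jax
    import jax.numpy as jnp

    from hackable_diffusion.lib.architecture import sequence_embedders

    # Illustrative input: a batch of 4 sequences, 16 positions, 8 channels.
    x = jnp.ones((4, 16, 8), dtype=jnp.float32)

    module = sequence_embedders.AdditiveSequenceEmbedding(num_features=8)

    # The learnable table has shape (1, *x.shape[1:]) and broadcasts over the
    # batch dimension when added to `x`.
    variables = module.init({"params": jax.random.PRNGKey(0)}, x)
    out = module.apply(variables, x)

    assert out.shape == x.shape
    assert variables["params"]["PositionalEmbeddingTensor"].shape == (1, 16, 8)

As the diff shows, `num_features` is only validated in `setup`; the embedding
shape follows the input's non-batch dimensions, which presumably keeps the
constructor signature consistent with the other sequence embedders.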
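
The patch also makes `is_training` keyword-only on the MLP, UNet, and UNet
block `__call__` methods (the added `*,` in each signature) and updates the
internal call sites to match. A sketch of what this means for callers, with
`module` standing in for the UNet or MLP backbone and `cond` for its
conditioning embeddings:

    # After this change, passing `is_training` positionally raises a TypeError.
    out = module(x, cond, is_training=True)   # OK
    out = module(x, cond, True)               # TypeError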