class DiT(ConditionalBackbone):
  """Diffusion Transformer (DiT) backbone.

  Based on
  https://arxiv.org/abs/2212.09748
  https://github.com/facebookresearch/DiT

  Main architectural differences from official implementation:
  - Use learned positional embeddings instead of 2D sinusoidal embeddings.


  Attributes:
    patch_size: The size of the patches.
    hidden_size: The dimension of the embedding.
    depth: The number of DiT blocks.
    num_heads: The number of attention heads.
    mlp_ratio: The ratio of the hidden dimension in the MLP to the input
      dimension.
    dropout_rate: The dropout rate. Not used since DiT does not use dropout.
    attention_use_rope: Whether to use rotary positional embeddings.
    attention_rope_position_type: The type of rotary positional embeddings.
    normalization_type: The normalization method to use.
    normalization_num_groups: The number of groups for group normalization.
    output_channels: The number of output channels. If None, defaults to the
      number of input channels.
    dtype: The data type of the computation.
  """

  patch_size: int
  hidden_size: int
  depth: int
  num_heads: int
  mlp_ratio: float = 4.0
  dropout_rate: float = 0.0
  attention_use_rope: bool = False
  attention_rope_position_type: RoPEPositionType = RoPEPositionType.SQUARE
  normalization_type: normalization.NormalizationType = (
      normalization.NormalizationType.RMS_NORM
  )
  normalization_num_groups: int | None = None
  output_channels: int | None = None
  dtype: DType = jnp.float32

  def setup(self):
    # Patch embedder and normalization factory are shared state; the DiT
    # blocks themselves are created inline in `__call__` (compact style).
    self.patch_embedder = dit_blocks.PatchEmbedder(
        patch_size=self.patch_size,
        hidden_size=self.hidden_size,
        dtype=self.dtype,
    )

    self.norm_factory = normalization.NormalizationLayerFactory(
        normalization_method=self.normalization_type,
        num_groups=self.normalization_num_groups,
        dtype=self.dtype,
    )

  @nn.compact
  @typechecked
  def __call__(
      self,
      x: Float["batch height width channels"],
      conditioning_embeddings: dict[ConditioningMechanism, Float["batch ..."]],
      *,
      is_training: bool,
  ) -> Float["batch height width output_channels"]:
    """Applies the DiT backbone to a batch of NHWC images.

    Args:
      x: Input image batch, NHWC layout.
      conditioning_embeddings: Mapping from conditioning mechanism to its
        embedding. Must contain an ADAPTIVE_NORM entry.
      is_training: Whether the model is run in training mode.

    Returns:
      An NHWC tensor with `output_channels` channels (the input channel
      count when `output_channels` is None).

    Raises:
      ValueError: If no ADAPTIVE_NORM conditioning embedding is provided.
    """
    b, h, w, input_channels = x.shape
    c = conditioning_embeddings.get(ConditioningMechanism.ADAPTIVE_NORM)
    if c is None:
      raise ValueError(
          "ADAPTIVE_NORM conditioning must be provided in"
          " conditioning_embeddings for DiT. Available keys in"
          f" conditioning_embeddings: {conditioning_embeddings.keys()}"
      )

    x_patches = self.patch_embedder(x)
    seq_dim = (h // self.patch_size) * (w // self.patch_size)
    assert x_patches.shape == (b, seq_dim, self.hidden_size)
    # x_patches is (B, L, D)

    # Add Positional Embeddings
    if not self.attention_use_rope:
      # Learnable additive positional embedding. NOTE(review): the embedding
      # parameter is created with the current sequence length, so this
      # assumes a fixed input resolution across calls — confirm with callers
      # before feeding variable-resolution inputs.
      x_patches = sequence_embedders.AdditiveSequenceEmbedding(
          num_features=self.hidden_size
      )(x_patches)

    # Blocks
    for i in range(self.depth):
      x_patches = dit_blocks.DiTBlock(
          norm_factory=self.norm_factory,
          hidden_size=self.hidden_size,
          num_heads=self.num_heads,
          mlp_ratio=self.mlp_ratio,
          use_rope=self.attention_use_rope,
          rope_position_type=self.attention_rope_position_type,
          dtype=self.dtype,
          name=f"DiTBlock_{i}",
      )(x_patches, c, is_training=is_training)

    # Final Layer
    num_output_channels = (
        input_channels if self.output_channels is None else self.output_channels
    )

    x_out = dit_blocks.FinalLayer(
        norm_factory=self.norm_factory,
        patch_size=self.patch_size,
        out_channels=num_output_channels,
        dtype=self.dtype,
        name="FinalLayer",
    )(x_patches, c)

    # Cast back to float32 if the computation ran in bfloat16.
    x_out = utils.optional_bf16_to_fp32(x_out)
    return x_out
class PatchEmbedder(nn.Module):
  """Embeds an image as a sequence of non-overlapping patch tokens.

  The image is cut into `patch_size x patch_size` tiles; each tile is
  linearly projected to `hidden_size` features via a strided convolution,
  and the spatial grid is flattened (row-major) into one sequence axis.

  Attributes:
    patch_size: The size of the patches.
    hidden_size: The dimension of the embedding.
    dtype: The data type of the computation.
  """

  patch_size: int
  hidden_size: int
  dtype: DType = jnp.float32

  @nn.compact
  @typechecked
  def __call__(
      self, x: Float["batch height width channels"]
  ) -> Float["batch sequence hidden_size"]:
    batch, height, width, _ = x.shape
    p = self.patch_size
    if height % p != 0 or width % p != 0:
      raise ValueError(
          f"Image dimensions ({height}, {width}) must be divisible by patch"
          f" size ({p})."
      )

    # A stride-p, kernel-p VALID convolution is exactly a per-patch linear
    # projection: one output position per non-overlapping tile.
    patches = nn.Conv(
        features=self.hidden_size,
        kernel_size=(p, p),
        strides=(p, p),
        padding="VALID",
        dtype=self.dtype,
        name="PatchEmbedder_Conv",
    )(x)
    # patches is (B, H//P, W//P, D); sanity-check before flattening.
    assert patches.shape == (batch, height // p, width // p, self.hidden_size)

    return einops.rearrange(patches, "b h w d -> b (h w) d")
class DiTBlock(nn.Module):
  """Diffusion Transformer Block.

  Implements AdaLN-Zero conditioning: the conditional normalization applies
  conditioning-dependent scale/shift, and zero-initialized gates modulate the
  residual branches, so the whole block is the identity at initialization.

  Attributes:
    norm_factory: Factory for creating normalization layers.
    hidden_size: The dimension of the hidden state.
    num_heads: The number of attention heads.
    mlp_ratio: The ratio of the hidden dimension in the MLP to the input
      dimension.
    use_rope: Whether to use rotary positional embeddings.
    rope_position_type: The type of rotary positional embeddings.
    dtype: The data type of the computation.
  """

  norm_factory: NormalizationLayerFactory
  hidden_size: int
  num_heads: int
  mlp_ratio: float = 4.0
  use_rope: bool = False
  rope_position_type: RoPEPositionType = RoPEPositionType.SQUARE
  dtype: DType = jnp.float32

  def setup(self):
    if not self.mlp_ratio > 0:
      raise ValueError("MLP ratio must be positive.")
    mlp_hidden_dim = int(self.hidden_size * self.mlp_ratio)
    if not mlp_hidden_dim > 0:
      raise ValueError("MLP hidden dimension must be positive.")

    # NOTE(review): a single conditional norm instance is shared between the
    # attention and MLP branches (the official DiT uses two independent
    # norms). The shared parameter structure is pinned by the unit tests, so
    # the sharing is kept here — confirm it is intentional.
    self.norm = self.norm_factory.conditional_norm_factory()
    self.attn = attention.MultiHeadAttention(
        num_heads=self.num_heads,
        use_rope=self.use_rope,
        rope_position_type=self.rope_position_type,
        zero_init_output=True,
        dtype=self.dtype,
    )
    self.mlp = mlp_blocks.MLP(
        hidden_sizes=[mlp_hidden_dim],
        output_size=self.hidden_size,
        activation="gelu",
        activate_final=False,
        zero_init_output=True,
        dtype=self.dtype,
        name="MLP",
    )
    # Part of AdaLNZero architecture: zero-initialized gates make both
    # residual branches no-ops at init. (Scale/shift is taken care of by the
    # Conditional Normalizations.)
    self.attn_gate = nn.Dense(
        features=self.hidden_size,
        kernel_init=nn.initializers.zeros_init(),
        bias_init=nn.initializers.zeros_init(),
        dtype=self.dtype,
        name="AttnGate",
    )
    self.mlp_gate = nn.Dense(
        features=self.hidden_size,
        kernel_init=nn.initializers.zeros_init(),
        bias_init=nn.initializers.zeros_init(),
        dtype=self.dtype,
        name="MLPGate",
    )

  @nn.compact
  @typechecked
  def __call__(
      self,
      x: Float["batch sequence hidden_size"],
      c: Float["batch cond_dim"],
      *,
      is_training: bool,
  ) -> Float["batch sequence hidden_size"]:
    # Gates are (batch, hidden); insert a singleton sequence axis so they
    # broadcast against the (batch, sequence, hidden) activations. (The
    # typed input is always rank 3, so no generic right-broadcast is needed.)
    attn_gate = self.attn_gate(c)[:, None, :]
    mlp_gate = self.mlp_gate(c)[:, None, :]
    x = x + attn_gate * self.attn(self.norm(x, c), c=None)

    # The MLP assumes a (batch, features) input; fold the sequence axis into
    # the batch axis and restore it afterwards.
    def _seq_mlp(h):
      b, t, d = h.shape
      out = self.mlp(h.reshape(b * t, d), is_training=is_training)
      return out.reshape(b, t, d)

    return x + mlp_gate * _seq_mlp(self.norm(x, c))
class FinalLayer(nn.Module):
  """Final layer of DiT.

  Applies a conditional normalization and a zero-initialized linear
  projection, then unpatchifies the sequence back into a square image.

  Attributes:
    norm_factory: Factory for creating normalization layers.
    patch_size: The size of the patches.
    out_channels: The number of output channels.
    dtype: The data type of the computation.
  """

  norm_factory: NormalizationLayerFactory
  patch_size: int
  out_channels: int
  dtype: DType = jnp.float32

  def setup(self):
    self.norm_final = self.norm_factory.conditional_norm_factory()
    # Zero init: the projected output starts at exactly zero.
    self.linear = nn.Dense(
        features=self.patch_size * self.patch_size * self.out_channels,
        kernel_init=nn.initializers.zeros_init(),
        bias_init=nn.initializers.zeros_init(),
        dtype=self.dtype,
        name="Final_Linear",
    )

  @nn.compact
  @typechecked
  def __call__(
      self,
      x: Float["batch sequence hidden_size"],
      c: Float["batch cond_dim"],
  ) -> Float["batch height width out_channels"]:
    x = self.norm_final(x, c)
    x = self.linear(x)

    # Unpatchify, assuming square image (and square patches).
    # Use a *rounded* Python sqrt instead of truncating jnp.sqrt: float
    # truncation can floor a perfect square to n - 1 and raise spuriously.
    b, l, _ = x.shape
    h = w = round(l**0.5)
    if (h * w) != l:
      raise ValueError(
          f"Sequence length ({l}) must be a perfect square to unpatchify"
          " into a square image."
      )

    # x is (B, H*W, P*P*C)
    x = x.reshape(b, h, w, self.patch_size, self.patch_size, self.out_channels)
    # (B, H, W, P, P, C) -> (B, H, P, W, P, C) -> (B, H*P, W*P, C)
    x = jnp.einsum("bhwpqc->bhpwqc", x)
    x = x.reshape(
        b, h * self.patch_size, w * self.patch_size, self.out_channels
    )
    return x
class DiTBlocksTest(parameterized.TestCase):
  """Tests for the DiT building blocks (patch embedder, block, final layer)."""

  def setUp(self):
    super().setUp()
    # Fixed PRNG key so parameter init (and thus shapes) is deterministic.
    self.key = jax.random.PRNGKey(0)
    self.is_training = False

  def test_patch_embedder(self):
    """Checks parameter shapes and output shape of PatchEmbedder."""
    x = jnp.ones((1, 32, 32, 3))
    layer = dit_blocks.PatchEmbedder(patch_size=4, hidden_size=16)
    variables = layer.init(self.key, x)
    # Checking that the variables have expected shapes.
    variables_shapes = test_utils.get_pytree_shapes(variables)
    expected_variables_shapes = {
        'params': {
            'PatchEmbedder_Conv': {
                'kernel': (4, 4, 3, 16),
                'bias': (16,),
            }
        }
    }
    self.assertDictEqual(expected_variables_shapes, variables_shapes)
    out = layer.apply(variables, x)
    # 32/4 = 8. 8*8 = 64 patches.
    self.assertEqual(out.shape, (1, 64, 16))

  def test_dit_block(self):
    """Checks parameter tree and output shape of a single DiTBlock."""
    x = jnp.ones((1, 64, 16))
    c = jnp.ones((1, 32))  # Condition
    norm_factory = normalization.NormalizationLayerFactory(
        normalization_method=normalization.NormalizationType.RMS_NORM
    )
    layer = dit_blocks.DiTBlock(
        norm_factory=norm_factory,
        hidden_size=16,
        num_heads=4,
        mlp_ratio=4.0,
    )
    variables = layer.init(
        {'params': self.key, 'dropout': self.key},
        x,
        c,
        is_training=self.is_training,
    )
    # Checking that the variables have expected shapes.
    variables_shapes = test_utils.get_pytree_shapes(variables)
    expected_variables_shapes = {
        'params': {
            'AttnGate': {'bias': (16,), 'kernel': (32, 16)},
            'ConditionalNorm': {
                'Dense_0': {'bias': (32,), 'kernel': (32, 32)},
                'RMSNorm_0': {'scale': (16,)},
            },
            'MLP': {
                'Dense_Hidden_0': {'bias': (64,), 'kernel': (16, 64)},
                'Dense_Output': {'bias': (16,), 'kernel': (64, 16)},
            },
            'MLPGate': {'bias': (16,), 'kernel': (32, 16)},
            'attn': {
                'Dense_K': {'bias': (16,), 'kernel': (16, 16)},
                'Dense_Output': {'bias': (16,), 'kernel': (16, 16)},
                'Dense_Q': {'bias': (16,), 'kernel': (16, 16)},
                'Dense_V': {'bias': (16,), 'kernel': (16, 16)},
            },
        }
    }
    self.assertDictEqual(expected_variables_shapes, variables_shapes)
    out = layer.apply(variables, x, c, is_training=self.is_training)
    self.assertEqual(out.shape, (1, 64, 16))

  def test_final_layer(self):
    """Checks parameter tree and unpatchified output shape of FinalLayer."""
    x = jnp.ones((1, 64, 16))
    c = jnp.ones((1, 32))
    norm_factory = normalization.NormalizationLayerFactory(
        normalization_method=normalization.NormalizationType.RMS_NORM
    )
    layer = dit_blocks.FinalLayer(
        norm_factory=norm_factory, patch_size=4, out_channels=3
    )
    variables = layer.init(self.key, x, c)
    # Checking that the variables have expected shapes.
    variables_shapes = test_utils.get_pytree_shapes(variables)
    expected_variables_shapes = {
        'params': {
            'ConditionalNorm': {
                'RMSNorm_0': {
                    'scale': (16,),
                },
                'Dense_0': {
                    'kernel': (32, 32),
                    'bias': (32,),
                },
            },
            'Final_Linear': {
                'kernel': (16, 4 * 4 * 3),
                'bias': (4 * 4 * 3,),
            },
        }
    }
    self.assertDictEqual(expected_variables_shapes, variables_shapes)
    out = layer.apply(variables, x, c)
    # H = sqrt(64) * 4 = 8 * 4 = 32.
    self.assertEqual(out.shape, (1, 32, 32, 3))
class DiTTest(parameterized.TestCase):
  """End-to-end test for the DiT backbone."""

  def setUp(self):
    super().setUp()
    # Fixed PRNG key so parameter init (and thus shapes) is deterministic.
    self.key = jax.random.PRNGKey(0)
    self.is_training = False

  def test_dit(self):
    """Checks the full parameter tree and output shape of a tiny DiT."""
    x = jnp.ones((1, 32, 32, 3))
    c = jnp.ones((1, 32))
    cond_embeddings = {arch_typing.ConditioningMechanism.ADAPTIVE_NORM: c}

    model = dit.DiT(
        patch_size=4,
        hidden_size=16,
        depth=2,
        num_heads=4,
        mlp_ratio=4.0,
        normalization_type=normalization.NormalizationType.RMS_NORM,
    )

    variables = model.init(
        {'params': self.key, 'dropout': self.key},
        x,
        cond_embeddings,
        is_training=self.is_training,
    )
    variables_shapes = test_utils.get_pytree_shapes(variables)
    expected_variables_shapes = {
        'params': {
            'AdditiveSequenceEmbedding_0': {
                'PositionalEmbeddingTensor': (1, 64, 16)
            },
            'DiTBlock_0': {
                'AttnGate': {'bias': (16,), 'kernel': (32, 16)},
                'ConditionalNorm': {
                    'Dense_0': {'bias': (32,), 'kernel': (32, 32)},
                    'RMSNorm_0': {'scale': (16,)},
                },
                'MLP': {
                    'Dense_Hidden_0': {'bias': (64,), 'kernel': (16, 64)},
                    'Dense_Output': {'bias': (16,), 'kernel': (64, 16)},
                },
                'MLPGate': {'bias': (16,), 'kernel': (32, 16)},
                'attn': {
                    'Dense_K': {'bias': (16,), 'kernel': (16, 16)},
                    'Dense_Output': {'bias': (16,), 'kernel': (16, 16)},
                    'Dense_Q': {'bias': (16,), 'kernel': (16, 16)},
                    'Dense_V': {'bias': (16,), 'kernel': (16, 16)},
                },
            },
            'DiTBlock_1': {
                'AttnGate': {'bias': (16,), 'kernel': (32, 16)},
                'ConditionalNorm': {
                    'Dense_0': {'bias': (32,), 'kernel': (32, 32)},
                    'RMSNorm_0': {'scale': (16,)},
                },
                'MLP': {
                    'Dense_Hidden_0': {'bias': (64,), 'kernel': (16, 64)},
                    'Dense_Output': {'bias': (16,), 'kernel': (64, 16)},
                },
                'MLPGate': {'bias': (16,), 'kernel': (32, 16)},
                'attn': {
                    'Dense_K': {'bias': (16,), 'kernel': (16, 16)},
                    'Dense_Output': {'bias': (16,), 'kernel': (16, 16)},
                    'Dense_Q': {'bias': (16,), 'kernel': (16, 16)},
                    'Dense_V': {'bias': (16,), 'kernel': (16, 16)},
                },
            },
            'FinalLayer': {
                'ConditionalNorm': {
                    'Dense_0': {'bias': (32,), 'kernel': (32, 32)},
                    'RMSNorm_0': {'scale': (16,)},
                },
                'Final_Linear': {'bias': (48,), 'kernel': (16, 48)},
            },
            'patch_embedder': {
                'PatchEmbedder_Conv': {'bias': (16,), 'kernel': (4, 4, 3, 16)}
            },
        }
    }
    self.assertDictEqual(expected_variables_shapes, variables_shapes)
    out = model.apply(
        variables, x, cond_embeddings, is_training=self.is_training
    )
    self.assertEqual(out.shape, (1, 32, 32, 3))
class AdditiveSequenceEmbedding(nn.Module):
  """Learnable additive sequence positional embedding.

  A parameter tensor with the input's trailing shape (and a broadcast batch
  axis of 1) is added to the input.

  Attributes:
    num_features: The expected trailing feature dimension of the input.
  """

  num_features: int

  def setup(self):
    if self.num_features <= 0:
      raise ValueError("Number of features must be positive.")

  @nn.compact
  @typechecked
  def __call__(
      self, x: Num["batch *#data_shape"]
  ) -> Float["batch *#data_shape"]:
    # `num_features` was previously only validated for positivity but never
    # checked against the input, silently accepting mismatched inputs.
    # Only check when a feature axis exists (rank >= 2); rank-1 inputs have
    # no feature dimension to compare against.
    if x.ndim >= 2 and x.shape[-1] != self.num_features:
      raise ValueError(
          f"Trailing input dimension ({x.shape[-1]}) does not match"
          f" num_features ({self.num_features})."
      )
    # NOTE(review): the parameter shape is bound to the *current* input
    # shape, so callers must use a fixed sequence length across calls.
    pos_embed = self.param(
        "PositionalEmbeddingTensor",
        nn.initializers.normal(stddev=0.02),
        (1, *x.shape[1:]),
        x.dtype,
    )
    return x + pos_embed
b/hackable_diffusion/lib/architecture/sequence_embedders_test.py @@ -38,7 +38,11 @@ def _get_invalid_num_features_params(): """Generates parameters for testing invalid num_features.""" params = [] - modes = ["sinusoidal_embedding", "random_fourier_embedding"] + modes = [ + "sinusoidal_embedding", + "random_fourier_embedding", + "additive_embedding", + ] feature_values = [ ("default", INVALID_INT), ("zero", 0), @@ -102,6 +106,10 @@ def test_sequence_embedding_raises_error_on_invalid_num_features( module = sequence_embedders.RandomFourierSequenceEmbedding( num_features=num_features ) + elif mode == "additive_embedding": + module = sequence_embedders.AdditiveSequenceEmbedding( + num_features=num_features + ) else: self.fail(f"Unknown mode: {mode}") inputs = jnp.arange(self.batch_size) @@ -244,6 +252,30 @@ def test_rope_embedding_has_no_params(self): variables = module.init(self.rng, x_rope) self.assertEmpty(variables) + # MARK: AdditiveSequenceEmbedding tests + + def test_additive_embedding_output_shape(self): + """Tests the output shape of AdditiveSequenceEmbedding.""" + module = sequence_embedders.AdditiveSequenceEmbedding(num_features=self.dim) + variables = module.init({"params": self.rng}, self.x) + output = module.apply(variables, self.x) + self.assertEqual(output.shape, self.x.shape) + + def test_additive_embedding_params_are_updated(self): + """Tests that AdditiveSequenceEmbedding params are updated.""" + module = sequence_embedders.AdditiveSequenceEmbedding(num_features=self.dim) + variables = module.init({"params": self.rng}, self.x) + initial_params = variables["params"] + + def loss_fn(params): + output = module.apply({"params": params}, self.x) + return jnp.sum(output) + + grads = jax.grad(loss_fn)(initial_params) + + # Check that the gradients are not zero. 
def get_pytree_shapes(pytree):
  """Recursively extracts the shape of every leaf in a PyTree.

  Args:
    pytree: Any JAX PyTree.

  Returns:
    A PyTree with the same structure where every leaf is replaced by its
    `.shape`, or by None for leaves that have no `shape` attribute.
  """

  def _leaf_shape(leaf):
    # Non-array leaves (ints, strings, ...) carry no shape; map them to None.
    return leaf.shape if hasattr(leaf, 'shape') else None

  return jax.tree_util.tree_map(_leaf_shape, pytree)
channels"], cross_attention_emb: Float["batch seq cond_dim2"] | None, + *, is_training: bool, ) -> Float["batch height width channels"]: skip = x