diff --git a/hackable_diffusion/lib/architecture/mlp.py b/hackable_diffusion/lib/architecture/mlp.py index 89bfa7c..cc4585a 100644 --- a/hackable_diffusion/lib/architecture/mlp.py +++ b/hackable_diffusion/lib/architecture/mlp.py @@ -78,6 +78,7 @@ def __call__( self, x: DataArray, conditioning_embeddings: dict[ConditioningMechanism, Float['batch ...']], + *, is_training: bool, ) -> DataArray: x_emb = jnp.reshape(x, shape=(x.shape[0], -1)) @@ -92,7 +93,7 @@ def __call__( dropout_rate=self.dropout_rate, dtype=self.dtype, name='PreprocessMLP', - )(x_emb, is_training) + )(x_emb, is_training=is_training) # The conditioning was already processed by the `conditioning_encoder`, so # here we just need to concatenate it with the `x`. @@ -125,7 +126,7 @@ def __call__( dtype=self.dtype, zero_init_output=self.zero_init_output, name='PostprocessMLP', - )(emb, is_training) + )(emb, is_training=is_training) output = jnp.reshape(output, shape=x.shape) output = utils.optional_bf16_to_fp32(output) diff --git a/hackable_diffusion/lib/architecture/mlp_blocks.py b/hackable_diffusion/lib/architecture/mlp_blocks.py index 7ef162c..29189ae 100644 --- a/hackable_diffusion/lib/architecture/mlp_blocks.py +++ b/hackable_diffusion/lib/architecture/mlp_blocks.py @@ -46,7 +46,7 @@ class MLP(nn.Module): @nn.compact @typechecked def __call__( - self, x: Float['batch num_inputs'], is_training: bool + self, x: Float['batch num_inputs'], *, is_training: bool ) -> Float['batch num_features']: """Applies MLP blocks to the input tensor. diff --git a/hackable_diffusion/lib/architecture/unet.py b/hackable_diffusion/lib/architecture/unet.py index fd6d499..105009b 100644 --- a/hackable_diffusion/lib/architecture/unet.py +++ b/hackable_diffusion/lib/architecture/unet.py @@ -164,6 +164,7 @@ def __call__( self, x: Float["batch height width channels"], conditioning_embeddings: dict[ConditioningMechanism, Float["batch ..."]], + *, is_training: bool, ) -> Float["batch height width output_channels"]: diff --git a/hackable_diffusion/lib/architecture/unet_blocks.py b/hackable_diffusion/lib/architecture/unet_blocks.py index 3199abb..134b465 100644 --- a/hackable_diffusion/lib/architecture/unet_blocks.py +++ b/hackable_diffusion/lib/architecture/unet_blocks.py @@ -266,6 +266,7 @@ def __call__( self, x: Float["batch height width channels"], cross_attention_emb: Float["batch seq cond_dim2"] | None, + *, is_training: bool, ) -> Float["batch height width channels"]: skip = x