diff --git a/init2winit/model_lib/mdlm_rope_nanodo.py b/init2winit/model_lib/mdlm_rope_nanodo.py index 97815764..921c6ef9 100644 --- a/init2winit/model_lib/mdlm_rope_nanodo.py +++ b/init2winit/model_lib/mdlm_rope_nanodo.py @@ -223,17 +223,33 @@ def _compute_elbo(self, params, batch, rng, train): total_loss = jnp.sum(weighted_loss_BxL) if pad_weights is not None: num_tokens = jnp.sum(pad_weights) + effective_num_tokens = jnp.sum( + is_masked_BxL.astype(jnp.float32) * pad_weights + ) else: num_tokens = jnp.array(B * L, dtype=jnp.float32) - return total_loss / (num_tokens + self.hps['epsilon']) + effective_num_tokens = jnp.sum(is_masked_BxL.astype(jnp.float32)) + return ( + total_loss / (num_tokens + self.hps['epsilon']), + num_tokens, + effective_num_tokens, + ) def evaluate_batch(self, params, batch_stats, batch): rng = batch['eval_rng'] - loss = self._compute_elbo(params, batch, rng, train=False) - return self.metrics_bundle.single_from_model_output(normalized_loss=loss) + loss, num_tokens, effective_num_tokens = self._compute_elbo( + params, batch, rng, train=False + ) + return self.metrics_bundle.single_from_model_output( + normalized_loss=loss, + num_tokens=num_tokens, + effective_num_tokens=effective_num_tokens, + ) def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): - loss = self._compute_elbo(params, batch, dropout_rng, train=True) + loss, num_tokens, effective_num_tokens = self._compute_elbo( + params, batch, dropout_rng, train=True + ) if self.hps.get('l2_decay_factor'): l2_loss = model_utils.l2_regularization( @@ -241,4 +257,7 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): ) loss += 0.5 * self.hps.l2_decay_factor * l2_loss - return loss, {} + return loss, { + 'num_tokens': num_tokens, + 'effective_num_tokens': effective_num_tokens, + } diff --git a/init2winit/model_lib/metrics.py b/init2winit/model_lib/metrics.py index 8f615669..c428405d 100644 --- a/init2winit/model_lib/metrics.py +++ b/init2winit/model_lib/metrics.py @@ -214,6 +214,48 @@ def compute(self): return self.count +@flax.struct.dataclass +class NumTokens(Metric): + """Computes the number of tokens seen.""" + + count: jnp.float32 + + @classmethod + def from_model_output(cls, num_tokens=None, **_): + if num_tokens is None: + return cls(count=jnp.array(0.0, dtype=jnp.float32)) + return cls(count=jnp.array(num_tokens, dtype=jnp.float32)) + + def merge(self, other): + """Merges two NumTokens metrics.""" + return type(self)(count=self.count + other.count) + + def compute(self): + """Computes the number of tokens.""" + return self.count + + +@flax.struct.dataclass +class EffectiveNumTokens(Metric): + """Computes the effective number of tokens seen (masked).""" + + count: jnp.float32 + + @classmethod + def from_model_output(cls, effective_num_tokens=None, **_): + if effective_num_tokens is None: + return cls(count=jnp.array(0.0, dtype=jnp.float32)) + return cls(count=jnp.array(effective_num_tokens, dtype=jnp.float32)) + + def merge(self, other): + """Merges two EffectiveNumTokens metrics.""" + return type(self)(count=self.count + other.count) + + def compute(self): + """Computes the effective number of tokens.""" + return self.count + + # Following the Flax OGB example: # https://github.com/google/flax/blob/main/examples/ogbg_molpcba/train.py @flax.struct.dataclass @@ -835,7 +877,10 @@ def compute(self): num_examples=NumExamples, ), 'mdlm_metrics': Collection.create( - ce_loss=average_ctc_loss(), perplexity=mdlm_perplexity() + ce_loss=average_ctc_loss(), + perplexity=mdlm_perplexity(), + num_tokens=NumTokens, + effective_num_tokens=EffectiveNumTokens, ), } diff --git a/init2winit/model_lib/test_metrics.py b/init2winit/model_lib/test_metrics.py index 581d687e..fb7bbcf2 100644 --- a/init2winit/model_lib/test_metrics.py +++ b/init2winit/model_lib/test_metrics.py @@ -457,13 +457,17 @@ class MDLMMetricsTest(absltest.TestCase): """Tests for the MDLM metrics bundle.""" def test_mdlm_bundle(self): - """MDLM bundle passes through correct ce_loss and perplexity.""" + """MDLM bundle passes through correct metrics.""" bundle = metrics.get_metrics('mdlm_metrics') result = bundle.single_from_model_output( - normalized_loss=jnp.float32(2.0) + normalized_loss=jnp.float32(2.0), + num_tokens=jnp.float32(100.0), + effective_num_tokens=jnp.float32(50.0), ).compute() np.testing.assert_allclose(result['ce_loss'], 2.0) np.testing.assert_allclose(result['perplexity'], jnp.exp(2.0), rtol=1e-5) + np.testing.assert_allclose(result['num_tokens'], 100.0) + np.testing.assert_allclose(result['effective_num_tokens'], 50.0) class MetricsBundleRegistryTest(parameterized.TestCase):