Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions init2winit/model_lib/mdlm_rope_nanodo.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,22 +223,41 @@ def _compute_elbo(self, params, batch, rng, train):
total_loss = jnp.sum(weighted_loss_BxL)
if pad_weights is not None:
num_tokens = jnp.sum(pad_weights)
effective_num_tokens = jnp.sum(
is_masked_BxL.astype(jnp.float32) * pad_weights
)
else:
num_tokens = jnp.array(B * L, dtype=jnp.float32)
return total_loss / (num_tokens + self.hps['epsilon'])
effective_num_tokens = jnp.sum(is_masked_BxL.astype(jnp.float32))
return (
total_loss / (num_tokens + self.hps['epsilon']),
num_tokens,
effective_num_tokens,
)

def evaluate_batch(self, params, batch_stats, batch):
rng = batch['eval_rng']
loss = self._compute_elbo(params, batch, rng, train=False)
return self.metrics_bundle.single_from_model_output(normalized_loss=loss)
loss, num_tokens, effective_num_tokens = self._compute_elbo(
params, batch, rng, train=False
)
return self.metrics_bundle.single_from_model_output(
normalized_loss=loss,
num_tokens=num_tokens,
effective_num_tokens=effective_num_tokens,
)

def training_cost(self, params, batch, batch_stats=None, dropout_rng=None):
loss = self._compute_elbo(params, batch, dropout_rng, train=True)
loss, num_tokens, effective_num_tokens = self._compute_elbo(
params, batch, dropout_rng, train=True
)

if self.hps.get('l2_decay_factor'):
l2_loss = model_utils.l2_regularization(
params, self.hps.l2_decay_rank_threshold
)
loss += 0.5 * self.hps.l2_decay_factor * l2_loss

return loss, {}
return loss, {
'num_tokens': num_tokens,
'effective_num_tokens': effective_num_tokens,
}
47 changes: 46 additions & 1 deletion init2winit/model_lib/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,48 @@ def compute(self):
return self.count


@flax.struct.dataclass
class NumTokens(Metric):
  """Accumulates the total number of (padding-weighted) tokens seen.

  Attributes:
    count: Running scalar total, kept as a float32 so it merges cleanly
      across hosts/steps.
  """

  count: jnp.float32

  @classmethod
  def from_model_output(cls, num_tokens=None, **_):
    # A model output that omits `num_tokens` contributes zero to the total,
    # so merging with other shards is a no-op rather than an error.
    value = 0.0 if num_tokens is None else num_tokens
    return cls(count=jnp.array(value, dtype=jnp.float32))

  def merge(self, other):
    """Returns a NumTokens whose count is the sum of both operands."""
    return type(self)(count=self.count + other.count)

  def compute(self):
    """Returns the accumulated token count."""
    return self.count


@flax.struct.dataclass
class EffectiveNumTokens(Metric):
  """Accumulates the effective (masked) number of tokens seen.

  Attributes:
    count: Running scalar total of masked tokens, stored as float32.
  """

  count: jnp.float32

  @classmethod
  def from_model_output(cls, effective_num_tokens=None, **_):
    # Missing `effective_num_tokens` is treated as contributing zero so
    # that merges with populated shards remain well defined.
    value = 0.0 if effective_num_tokens is None else effective_num_tokens
    return cls(count=jnp.array(value, dtype=jnp.float32))

  def merge(self, other):
    """Returns an EffectiveNumTokens summing both operands' counts."""
    return type(self)(count=self.count + other.count)

  def compute(self):
    """Returns the accumulated effective token count."""
    return self.count


# Following the Flax OGB example:
# https://github.com/google/flax/blob/main/examples/ogbg_molpcba/train.py
@flax.struct.dataclass
Expand Down Expand Up @@ -835,7 +877,10 @@ def compute(self):
num_examples=NumExamples,
),
'mdlm_metrics': Collection.create(
ce_loss=average_ctc_loss(), perplexity=mdlm_perplexity()
ce_loss=average_ctc_loss(),
perplexity=mdlm_perplexity(),
num_tokens=NumTokens,
effective_num_tokens=EffectiveNumTokens,
),
}

Expand Down
8 changes: 6 additions & 2 deletions init2winit/model_lib/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,13 +457,17 @@ class MDLMMetricsTest(absltest.TestCase):
"""Tests for the MDLM metrics bundle."""

def test_mdlm_bundle(self):
"""MDLM bundle passes through correct ce_loss and perplexity."""
"""MDLM bundle passes through correct metrics."""
bundle = metrics.get_metrics('mdlm_metrics')
result = bundle.single_from_model_output(
normalized_loss=jnp.float32(2.0)
normalized_loss=jnp.float32(2.0),
num_tokens=jnp.float32(100.0),
effective_num_tokens=jnp.float32(50.0),
).compute()
np.testing.assert_allclose(result['ce_loss'], 2.0)
np.testing.assert_allclose(result['perplexity'], jnp.exp(2.0), rtol=1e-5)
np.testing.assert_allclose(result['num_tokens'], 100.0)
np.testing.assert_allclose(result['effective_num_tokens'], 50.0)


class MetricsBundleRegistryTest(parameterized.TestCase):
Expand Down
Loading