diff --git a/src/maxtext/input_pipeline/multihost_dataloading.py b/src/maxtext/input_pipeline/multihost_dataloading.py index 7757c27313..9e89447f09 100644 --- a/src/maxtext/input_pipeline/multihost_dataloading.py +++ b/src/maxtext/input_pipeline/multihost_dataloading.py @@ -18,6 +18,7 @@ Adapted from Sholto's: https://github.com/sholtodouglas/multihost_dataloading """ +import logging from functools import partial from typing import Union, Sequence from collections.abc import Iterator, Iterable @@ -198,6 +199,11 @@ def reset(self): def get_next(self, dummy_array): """Gets the next batch of data and forms a global array.""" local_data = next(self.iterator) + logging.info( + "RemoteIterator get_next got local_data['inputs'] of shape %s, will split into %s devices", + local_data["inputs"].shape, + len(list(dummy_array.sharding.addressable_devices)), + ) def form_global_array_colocated_python(path, array, devices, global_shape, sharding): try: @@ -223,6 +229,13 @@ def form_global_array_colocated_python(path, array, devices, global_shape, shard def save_state(self, step_array): """Saves the iterator state to a file.""" + logging.info( + "RemoteIterator save_state received step_array shape %s, num addressable shards %s, shard 0 shape %s", + str(step_array.shape), + str(len(step_array.addressable_shards)), + str(step_array.addressable_data(0).shape), + ) + logging.info("RemoteIterator step_array.addressable_data(0) is %s", str(step_array.addressable_data(0))) step = step_array.addressable_data(0).item() directory = epath.Path(self.checkpoint_path) / str(step) / "iter" if self.elastic: @@ -242,6 +255,14 @@ def save_state(self, step_array): return step_array def restore_state(self, step_array): + """Restore the iterator state from a checkpoint.""" + logging.info( + "RemoteIterator restore_state received step_array shape %s, num addressable shards %s, shard 0 shape %s", + str(step_array.shape), + str(len(step_array.addressable_shards)), + str(step_array.addressable_data(0).shape), + ) + logging.info("RemoteIterator step_array.addressable_data(0) is %s", str(step_array.addressable_data(0))) step = step_array.addressable_data(0).item() directory = epath.Path(self.checkpoint_path) / str(step) / "iter" if self.elastic: @@ -257,12 +278,16 @@ class RemoteIteratorWrapper: """Wrapper for RemoteIterator that handles device placement.""" def __init__(self, get_ds_fn, preprocessing_fn, global_mesh, global_shape, checkpoint_path="", elastic=False): + max_logging.log(f"RemoteIteratorWrapper: received global_mesh = {global_mesh}") + max_logging.log(f"RemoteIteratorWrapper: received global_shape = {global_shape}") self.cpu_devices = _colocated_cpu_devices(jax.local_devices()) self.tpu_devices = jax.local_devices() self.cpu_mesh = _colocated_cpu_mesh(global_mesh) self.tpu_sharding = jax.sharding.NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names)) self.cpu_sharding = jax.sharding.NamedSharding(self.cpu_mesh, PartitionSpec(self.cpu_mesh.axis_names)) self.dummy_array = jnp.zeros((len(self.cpu_devices))) + max_logging.log(f"RemoteIteratorWrapper: number of cpu devices {len(self.cpu_devices)}") + max_logging.log(f"RemoteIteratorWrapper: cpu_sharding: {self.cpu_sharding}") self.dummy_array = jax.device_put(self.dummy_array, self.cpu_sharding) # This is a proxy to a RemoteIterator running in a colocated process, # named "local_iterator" to match MultiHostDataLoadIterator's interface. @@ -282,10 +307,12 @@ def __next__(self): def save_state(self, step): step_array = jnp.full(self.dummy_array.shape, step, dtype=jnp.int32) + max_logging.log(f"RemoteIteratorWrapper: calling saving_state with {step_array.shape=}") step_array = jax.device_put(step_array, self.cpu_sharding) self.local_iterator.save_state(step_array) def restore_state(self, step): step_array = jnp.full(self.dummy_array.shape, step, dtype=jnp.int32) + max_logging.log(f"RemoteIteratorWrapper: calling restore_state with {step_array.shape=}") step_array = jax.device_put(step_array, self.cpu_sharding) self.local_iterator.restore_state(step_array)