Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/maxtext/input_pipeline/multihost_dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Adapted from Sholto's:
https://github.com/sholtodouglas/multihost_dataloading
"""
import logging
from functools import partial
from typing import Union, Sequence
from collections.abc import Iterator, Iterable
Expand Down Expand Up @@ -198,6 +199,11 @@ def reset(self):
def get_next(self, dummy_array):
"""Gets the next batch of data and forms a global array."""
local_data = next(self.iterator)
logging.info(
"RemoteIterator get_next got local_data['inputs'] of shape %s, will split into %s devices",
local_data["inputs"].shape,
len(list(dummy_array.sharding.addressable_devices)),
)
Comment on lines +202 to +206
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

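Why does this use the standard `logging` module rather than `max_logging`, which the other log calls added in this file use?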


def form_global_array_colocated_python(path, array, devices, global_shape, sharding):
try:
Expand All @@ -223,6 +229,13 @@ def form_global_array_colocated_python(path, array, devices, global_shape, shard

def save_state(self, step_array):
"""Saves the iterator state to a file."""
logging.info(
"RemoteIterator save_state received step_array shape %s, num addressable shards %s, shard 0 shape %s",
str(step_array.shape),
str(len(step_array.addressable_shards)),
str(step_array.addressable_data(0).shape),
)
logging.info("RemoteIterator step_array.addressable_data(0) is %s", str(step_array.addressable_data(0)))
step = step_array.addressable_data(0).item()
directory = epath.Path(self.checkpoint_path) / str(step) / "iter"
if self.elastic:
Expand All @@ -242,6 +255,14 @@ def save_state(self, step_array):
return step_array

def restore_state(self, step_array):
"""Restore the iterator state from a checkpoint."""
logging.info(
"RemoteIterator restore_state received step_array shape %s, num addressable shards %s, shard 0 shape %s",
str(step_array.shape),
str(len(step_array.addressable_shards)),
str(step_array.addressable_data(0).shape),
)
logging.info("RemoteIterator step_array.addressable_data(0) is %s", str(step_array.addressable_data(0)))
step = step_array.addressable_data(0).item()
directory = epath.Path(self.checkpoint_path) / str(step) / "iter"
if self.elastic:
Expand All @@ -257,12 +278,16 @@ class RemoteIteratorWrapper:
"""Wrapper for RemoteIterator that handles device placement."""

def __init__(self, get_ds_fn, preprocessing_fn, global_mesh, global_shape, checkpoint_path="", elastic=False):
max_logging.log(f"RemoteIteratorWrapper: received global_mesh = {global_mesh}")
max_logging.log(f"RemoteIteratorWrapper: received global_shape = {global_shape}")
self.cpu_devices = _colocated_cpu_devices(jax.local_devices())
self.tpu_devices = jax.local_devices()
self.cpu_mesh = _colocated_cpu_mesh(global_mesh)
self.tpu_sharding = jax.sharding.NamedSharding(global_mesh, PartitionSpec(global_mesh.axis_names))
self.cpu_sharding = jax.sharding.NamedSharding(self.cpu_mesh, PartitionSpec(self.cpu_mesh.axis_names))
self.dummy_array = jnp.zeros((len(self.cpu_devices)))
max_logging.log(f"RemoteIteratorWrapper: number of cpu devices {len(self.cpu_devices)}")
max_logging.log(f"RemoteIteratorWrapper: cpu_sharding: {self.cpu_sharding}")
self.dummy_array = jax.device_put(self.dummy_array, self.cpu_sharding)
# This is a proxy to a RemoteIterator running in a colocated process,
# named "local_iterator" to match MultiHostDataLoadIterator's interface.
Expand All @@ -282,10 +307,12 @@ def __next__(self):

def save_state(self, step):
    """Ask the local iterator proxy to save its state for `step`.

    Builds an int32 array with the same shape as `self.dummy_array`, filled
    with `step`, places it using `self.cpu_sharding`, and hands it to
    `self.local_iterator.save_state`.

    Args:
      step: Value written into every element of the step array.
    """
    step_array = jnp.full(self.dummy_array.shape, step, dtype=jnp.int32)
    # The message names the method that is actually invoked (`save_state`),
    # so that searching the logs for the method name finds this line.
    max_logging.log(f"RemoteIteratorWrapper: calling save_state with {step_array.shape=}")
    step_array = jax.device_put(step_array, self.cpu_sharding)
    self.local_iterator.save_state(step_array)

def restore_state(self, step):
    """Ask the local iterator proxy to restore its state for `step`.

    An int32 array shaped like `self.dummy_array` is filled with `step`,
    placed using `self.cpu_sharding`, and passed to
    `self.local_iterator.restore_state`.

    Args:
      step: Value written into every element of the step array.
    """
    fill_shape = self.dummy_array.shape
    step_array = jnp.full(fill_shape, step, dtype=jnp.int32)
    # Logged before device placement, so the shape reported is that of the freshly built array.
    max_logging.log(f"RemoteIteratorWrapper: calling restore_state with {step_array.shape=}")
    self.local_iterator.restore_state(jax.device_put(step_array, self.cpu_sharding))
Loading