Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions python/mlx/_distributed_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ def tb_connectivity_to_dot(hosts, tb_hosts, uuid_reverse_index):
for p in h.ports:
if not p.connected_to:
continue
if p.connected_to not in uuid_reverse_index:
continue
dst = uuid_reverse_index[p.connected_to]
if dst[0] < i:
continue
Expand Down Expand Up @@ -365,7 +367,7 @@ def check_valid_ring(hosts, rings, strict=True):
return has_ring


def check_ssh_connections(hosts):
def check_ssh_connections(hosts, ignore_unreachable=False):
results = [None] * len(hosts)

def _check(hostname, i):
Expand Down Expand Up @@ -417,7 +419,7 @@ def _check(hostname, i):
for t in threads:
t.join()

if not all(results):
if not all(results) and not ignore_unreachable:
log_error("Could not ssh to the following hosts:")
for i, h in enumerate(hosts):
if not results[i]:
Expand Down Expand Up @@ -493,13 +495,13 @@ def configure_jaccl_ring(args, hosts, ips, ring, sshinfo):
peer_left = ring[i - 1]
peer_right = ring[(i + 1) % num_nodes]
rdmas = []
for j in range(len(hosts)):
if j not in (peer_left, peer_right):
for other in ring:
if other not in (peer_left, peer_right):
rdmas.append(None)
else:
rdma = []
for c in range(count):
rdma.append(f"rdma_{ips.ips[i, j][c][0]}")
rdma.append(f"rdma_{ips.ips[node, other][c][0]}")
rdmas.append(rdma[0] if count == 1 else rdma)
jaccl_hosts.append(Host(i, h.ssh_hostname, h.ips, rdmas))
hostfile = Hostfile(jaccl_hosts, "jaccl-ring", args.env)
Expand Down Expand Up @@ -573,6 +575,11 @@ def main():
parser.add_argument(
"--hosts", default="127.0.0.1", help="A comma separated list of hosts"
)
parser.add_argument(
"--ignore-unreachable",
action="store_true",
help="Ignore hosts that are not reachable via ssh",
)
parser.add_argument("--hostfile", help="The file containing the hosts")
parser.add_argument(
"--over",
Expand Down Expand Up @@ -619,7 +626,9 @@ def main():
args.verbose,
f"Checking for ssh access for {', '.join(h.ssh_hostname for h in hosts)}",
)
sshinfo = check_ssh_connections(hosts)
sshinfo = check_ssh_connections(hosts, args.ignore_unreachable)
hosts = [h for r, h in zip(sshinfo, hosts) if r]
sshinfo = [r for r in sshinfo if r]

# Prepare a hostfile for communication over ethernet using the ips of the
# provided hostnames
Expand Down
Loading