# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict, defaultdict
from typing import Any, cast
import jax
import numpy as np
from jaxpp.jax_compat import core as jcore
from jaxpp.mesh import MpmdMesh
from jaxpp.types import MpmdSharding
from jaxpp.utils import filter_axes, get_named_sharding, update_named_sharding
class MpmdArray:
    """An array distributed across one or more MPMD groups.

    MpmdArray represents a logical array that exists in a subset of MPMD groups
    within an :class:`~jaxpp.mesh.MpmdMesh`.

    Unlike standard JAX SPMD arrays where all processes have a shard of the
    corresponding array, an MpmdArray is only "partially addressable", it exists
    as sharded in only one or more MPMD groups (potentially all of them too).

    Properties:
        - `mpmd_idxs`: The set of MPMD group indices where this array exists.
          Most computed arrays exist in a single group (len=1), but constants
          and loop invariants may be replicated across multiple groups when
          needed as inputs by multiple pipeline stages.
        - `is_partially_addressable`: A process can only access array shards if
          it belongs to one of the MPMD groups in mpmd_idxs.
          Use `to_mpmd_local_array` to get the SPMD JAX array
          (potentially spanning multiple devices) for this MPMD group.
        - `is_mpmd_replicated`: When `len(mpmd_idxs) > 1`, the array data is
          replicated across those groups.

    The `sharding` property returns a NamedSharding whose mesh spans all devices
    in mpmd_idxs, useful for resharding between SPMD and MPMD layouts.

    Example:
        An array with mpmd_idxs={0, 2} on a 4-group MPMD mesh exists in groups
        0 and 2. Processes in groups 1 and 3 cannot access this array's data
        (is_partially_addressable=False for them).

    Attributes:
        spec: The PartitionSpec describing how the array is sharded within
            each MPMD group.
        aval: The abstract value (shape and dtype) of the array.
    """

    def __init__(
        self,
        partially_addressable_arrays: list[jax.Array],
        mpmd_sharding: MpmdSharding,
        shape: tuple[int, ...] | None = None,
        dtype: jax.numpy.dtype | None = None,
    ):
        mpmd_mesh = mpmd_sharding.mpmd_mesh
        mpmd_idxs = mpmd_sharding.mesh_ids
        spec = mpmd_sharding.spec
        self._mpmd_sharding = mpmd_sharding
        self._mpmd_mesh = mpmd_mesh
        self._mpmd_idxs = tuple(sorted(mpmd_idxs))
        if mpmd_mesh.jax_mesh.is_multi_process:
            # In a multi-process setting, this process participates in at most
            # one MPMD group, so it can hold at most one local array.
            assert len(partially_addressable_arrays) <= 1
        # Validate each local array and key it by the MPMD group it lives on.
        partially_addressable_arrays_map = {}
        for idx, arr in enumerate(partially_addressable_arrays):
            mesh = get_named_sharding(arr).mesh
            if (mpmd_idx := mpmd_mesh.mpmd_idx_for_mesh.get(mesh)) is None:
                raise ValueError(
                    f"Argument array {idx} {arr.shape} is not on a mesh that is part"
                    f" of mpmd_mesh={mpmd_mesh.jax_mesh}"
                )
            if mpmd_idx not in mpmd_idxs:
                # BUGFIX: the second fragment was a plain string, so the
                # message printed the literal text "{mpmd_idxs}".
                raise ValueError(
                    f"Argument array's ({idx} {arr.shape}) mpmd_idx={mpmd_idx} not "
                    f"in mpmd_idxs={mpmd_idxs}"
                )
            if mpmd_idx in partially_addressable_arrays_map:
                raise ValueError(
                    f"Argument array {idx} {arr.shape} already has a "
                    f"mpmd_idx={mpmd_idx}"
                )
            partially_addressable_arrays_map[mpmd_idx] = arr
        self._partially_addressable_arrays: OrderedDict[int, jax.Array] = OrderedDict(
            sorted(partially_addressable_arrays_map.items(), key=lambda x: x[0])
        )
        if len(self._partially_addressable_arrays) == 0:
            # No local data: metadata must be provided explicitly.
            assert spec is not None
            assert shape is not None
            assert dtype is not None
        else:
            first_value = next(iter(self._partially_addressable_arrays.values()))
            shape = shape if shape is not None else first_value.shape
            dtype = dtype if dtype is not None else first_value.dtype
        # Consistency checks across local arrays (vacuous when there are none).
        shapes = [a.shape for a in self._partially_addressable_arrays.values()]
        assert all(_ == shape for _ in shapes), (shape, shapes)
        dtypes = [a.dtype for a in self._partially_addressable_arrays.values()]
        assert all(_ == dtype for _ in dtypes), (dtype, dtypes)
        # Every local array's spec, with the mpmd axis removed, must match the
        # canonical spec of the target MpmdSharding.
        mpmd_axis = mpmd_sharding.mpmd_mesh.mpmd_axis_name
        specs = [
            filter_axes(get_named_sharding(a).spec, {mpmd_axis})
            for a in self._partially_addressable_arrays.values()
        ]
        assert all(_ == mpmd_sharding.spec for _ in specs), (
            mpmd_sharding.spec,
            specs,
        )
        self._spec = mpmd_sharding.spec
        self._sharding = mpmd_sharding.sharding
        # TODO: maybe add sharding/vma/memory_space to aval
        self.aval = jcore.ShapedArray(shape, dtype, weak_type=False)

    @property
    def spec(self) -> jax.sharding.PartitionSpec:
        """PartitionSpec of the array within each MPMD group (mpmd axis removed)."""
        return self._spec

    @property
    def shape(self) -> tuple[int, ...]:
        return self.aval.shape

    @property
    def dtype(self) -> jax.numpy.dtype:
        return self.aval.dtype

    @property
    def ndim(self) -> int:
        return len(self.shape)

    @property
    def sharding(self) -> jax.sharding.NamedSharding:
        """
        NOTE: this is different from self.to_mpmd_local_array.sharding
        if self.is_mpmd_replicated
        """
        return self._sharding

    @property
    def _mpmd_local_sharding(self) -> jax.sharding.NamedSharding:
        # Sharding of the array as seen from within a single MPMD group.
        return jax.sharding.NamedSharding(self._mpmd_mesh.lowering_mesh(), self.spec)

    def __repr__(self):
        return (
            f"MpmdArray(shape={self.shape}, dtype={self.dtype}, "
            f"mpmd_idxs={self._mpmd_idxs}, sharding={self._sharding})"
        )

    @property
    def is_mpmd_replicated(self) -> bool:
        """
        Returns True if the array is replicated in more than one mpmd rank.
        """
        return len(self._mpmd_idxs) > 1

    @property
    def is_partially_addressable(self) -> bool:
        """
        Returns True if the array is partially addressable in the mpmd rank
        this process participates in.

        An array is partially addressable at this rank if this rank holds a shard of
        the array (the shard can potentially be replicated across multiple mpmd ranks).
        """
        return len(self._partially_addressable_arrays) > 0

    def delete(self):
        """Delete the buffers of all locally held arrays."""
        assert self.is_partially_addressable, "Array is not partially addressable"
        assert not self.is_deleted(), "Array is deleted"
        for arr in self._partially_addressable_arrays.values():
            arr.delete()

    def is_deleted(self) -> bool:
        """Whether the locally held arrays' buffers have been deleted."""
        assert self.is_partially_addressable, "Array is not partially addressable"
        if len(self._partially_addressable_arrays) == 1:
            return next(iter(self._partially_addressable_arrays.items()))[
                1
            ].is_deleted()
        # With multiple local replicas, either all are deleted or none are.
        _ = [a.is_deleted() for a in self._partially_addressable_arrays.values()]
        deleted = any(_)
        assert deleted == all(_)
        return deleted

    @property
    def to_mpmd_local_array(self) -> jax.Array | list[jax.Array] | None:
        """
        Returns a jax.Array if the array is partially addressable in the mpmd rank
        this process participates in.
        Otherwise, returns None.

        Returns a list of arrays when it's a single process, multiple-devices mesh.
        """
        if not self.is_partially_addressable:
            return None
        assert not self.is_deleted(), "Array is deleted"
        els = list(self._partially_addressable_arrays.values())
        if len(els) == 1:
            return els[0]
        return els

    @property
    def first_mpmd_replica(self) -> jax.Array | None:
        """The local array for the smallest mpmd_idx, or None if not held here."""
        if not self.is_partially_addressable:
            return None
        assert not self.is_deleted(), "Array is deleted"
        mpmd_idx, array = next(iter(self._partially_addressable_arrays.items()))
        if mpmd_idx == self._mpmd_idxs[0]:
            return array
        return None

    def __int__(self):
        assert self.is_partially_addressable, "Array is not partially addressable"
        return int(self.to_mpmd_local_array)

    def __format__(self, format_spec):
        assert self.is_partially_addressable, "Array is not partially addressable"
        return format(self.to_mpmd_local_array, format_spec)

    def block_until_ready(self):
        """Block until all locally held arrays are ready; returns self."""
        for arr in self._partially_addressable_arrays.values():
            arr.block_until_ready()
        return self
def pytype_aval_mapping(self: MpmdArray) -> jcore.AbstractValue:
    """Map an MpmdArray to its abstract value for JAX's aval machinery.

    When avals carry sharding information, attach this array's MPMD-local
    sharding; otherwise return the plain aval unchanged.
    """
    aval = self.aval
    if not hasattr(aval, "sharding"):
        return aval
    return jcore.update_aval_with_sharding(aval, self._mpmd_local_sharding)


# Register MpmdArray so JAX can derive avals for it (e.g. under tracing).
jcore.pytype_aval_mappings[MpmdArray] = pytype_aval_mapping
def _to_global_jax_array(mpmd_array: MpmdArray) -> jax.Array | None:
    """Assemble a global jax.Array from the locally addressable shards.

    When this process holds no shards, an empty-shard global array is built if
    the jax version/config supports it; otherwise None is returned.
    """
    if mpmd_array.is_partially_addressable:
        local_shards = [
            shard.data
            for arr in mpmd_array._partially_addressable_arrays.values()
            for shard in arr.addressable_shards
        ]
        return jax.make_array_from_single_device_arrays(
            shape=mpmd_array.shape,
            sharding=mpmd_array._sharding,
            arrays=local_shards,
            dtype=mpmd_array.dtype,
        )
    empty_arrays_supported = getattr(
        jax.config, "jax_enable_empty_arrays", False
    ) or jax.__version_info__ >= (0, 7, 1)
    if empty_arrays_supported:
        return jax.make_array_from_single_device_arrays(
            shape=mpmd_array.shape,
            sharding=mpmd_array._sharding,
            arrays=[],
            dtype=mpmd_array.dtype,
        )
    return None
def _id(*xs):
return xs
def _spmd_to_mpmd_reshard(
    mpmd_mesh: MpmdMesh,
    spmd_values: list[jax.Array],
    dist_shardings: list[MpmdSharding],
    donate: list[bool] | None = None,
) -> list[MpmdArray]:
    """Reshard a batch of SPMD arrays into per-group MPMD layouts.

    First replicates each value across all MPMD ranks via a jitted identity
    (with the mpmd axis filtered from the target sharding), then keeps only
    the shards belonging to each value's target mesh_ids, deleting the rest.

    Args:
        mpmd_mesh: The MPMD mesh definition.
        spmd_values: Source SPMD arrays (one per target sharding).
        dist_shardings: Target MpmdSharding for each value.
        donate: Per-value flag; when True the input array is deleted after the
            replication step. Defaults to all-False.

    Returns:
        In the multi-process case, a list of MpmdArray.
        NOTE(review): in the single-process case this currently returns a list
        of plain jax.Array (see TODO below) — callers must handle both.
    """
    if donate is None:
        donate = [False] * len(spmd_values)
    for spmd_value, dist_sharding in zip(spmd_values, dist_shardings):
        assert isinstance(
            spmd_value.sharding,
            (jax.sharding.NamedSharding, jax.sharding.SingleDeviceSharding),
        ), f"Unsupported sharding type: {spmd_value.sharding}"
    # NOTE: We filter out the mpmd axis from the sharding so that
    # the output is replicated across all mpmd ranks.
    _actual_shardings = tuple(
        update_named_sharding(
            filter_axes(
                cast(jax.sharding.NamedSharding, dist_sharding.sharding),
                {mpmd_mesh.mpmd_axis_name},
            ),
            mesh=mpmd_mesh.jax_mesh,
        )
        for dist_sharding in dist_shardings
    )
    # Identity jit performs the actual resharding/replication across ranks.
    res: list[jax.Array] = jax.jit(_id, out_shardings=_actual_shardings)(*spmd_values)
    for spmd_value, donated in zip(spmd_values, donate, strict=True):
        if donated:
            spmd_value.delete()
    if not mpmd_mesh.jax_mesh.is_multi_process:
        # TODO: return an jaxpp.MpmdArray instead of a list of jax.Array
        _res = []
        for _, dsh in zip(res, dist_shardings, strict=True):
            # Keep only shards on devices belonging to the target mesh_ids;
            # eagerly free the rest.
            shards = []
            for s in _.addressable_shards:
                if mpmd_mesh.device_mpmd_idx[s.device] in dsh.mesh_ids:
                    shards.append(s.data)
                else:
                    s.data.delete()
            _arr = jax.make_array_from_single_device_arrays(
                _.shape,
                jax.sharding.NamedSharding(
                    mpmd_mesh.mpmd_submesh(list(dsh.mesh_ids)).jax_mesh,
                    _.sharding.spec,
                ),
                shards,
            )
            _res.append(_arr)
        return _res
    _res = []
    for arr, dsh in zip(res, dist_shardings, strict=True):
        mesh_ids = dsh.mesh_ids
        # MpmdSharding.__post_init__ canonicalizes the spec by filtering out
        # the mpmd axis, so we can just use dsh.spec directly
        mpmd_sharding = MpmdSharding(
            mpmd_mesh=dsh.mpmd_mesh, mesh_ids=dsh.mesh_ids, spec=dsh.spec
        )
        if mpmd_mesh.my_mpmd_axis_index not in mesh_ids:
            # This process's MPMD group does not hold this value: record an
            # unaddressable MpmdArray and free the replicated copy.
            _res.append(
                MpmdArray(
                    partially_addressable_arrays=[],
                    mpmd_sharding=mpmd_sharding,
                    shape=arr.shape,
                    dtype=arr.dtype,
                )
            )
            arr.delete()
        else:
            # Rewrap the local shards onto this process's MPMD group mesh.
            new_arr = jax.make_array_from_single_device_arrays(
                arr.shape,
                jax.sharding.NamedSharding(
                    mpmd_mesh.my_mpmd_group_mesh, arr.sharding.spec
                ),
                [s.data for s in arr.addressable_shards],
            )
            _res.append(
                MpmdArray(
                    partially_addressable_arrays=[new_arr], mpmd_sharding=mpmd_sharding
                )
            )
    return _res
def _get_working_memory_threshold() -> int:
    """Get the minimum available working memory across all devices globally.

    Computes `bytes_limit - peak_bytes_in_use` for every local device, takes
    the minimum, reconciles it across processes, and returns a third of it as
    a conservative working-memory budget.

    Returns:
        The per-group memory threshold in bytes.
    """
    min_available = float("inf")
    for d in jax.local_devices():
        stats = d.memory_stats()
        # NOTE(review): memory_stats() may return None on backends without
        # memory statistics (e.g. CPU) — this assumes accelerator devices;
        # confirm before using on other backends.
        available = stats["bytes_limit"] - stats["peak_bytes_in_use"]
        min_available = min(min_available, available)
    # Ensure all processes use the same threshold by computing global minimum
    if jax.process_count() > 1:
        # BUGFIX: `jax.experimental.multihost_utils` is not bound by a plain
        # `import jax`; it must be imported explicitly.
        from jax.experimental import multihost_utils

        # Use process_allgather to collect all local minimums, then compute global min
        # Note: process_allgather requires an array, not a scalar
        local_min_array = jax.numpy.array(min_available, dtype=jax.numpy.float32)
        all_mins = multihost_utils.process_allgather(local_min_array, tiled=True)
        min_available = float(jax.numpy.min(all_mins))
    # Keep a 3x safety margin over the observed minimum.
    return int(min_available // 3)
def _build_mpmd_interleaved_order(
    arrays: list[jax.Array],
    shardings: list[MpmdSharding],
) -> list[int]:
    """Order array indices interleaved across mpmd groups, largest first.

    Indices are bucketed by their sharding's smallest mesh id, each bucket is
    sorted ascending by local shard size, and the result is produced by
    round-robin popping the (largest remaining) tail of each bucket.
    """
    buckets: dict[int, list[int]] = defaultdict(list)
    for position, sharding in enumerate(shardings):
        buckets[min(sharding.mesh_ids)].append(position)
    # Ascending size within each bucket so that .pop() yields largest first.
    for bucket in buckets.values():
        bucket.sort(key=lambda pos: arrays[pos].addressable_shards[0].data.nbytes)
    interleaved: list[int] = []
    while any(buckets.values()):
        for key in list(buckets.keys()):
            if buckets[key]:
                interleaved.append(buckets[key].pop())
    return interleaved
def spmd_to_mpmd_reshard(
    mpmd_mesh: MpmdMesh, spmd_arrays, mpmd_shardings, threshold: int | None = None
):
    """
    Reshards a pytree of SPMD arrays to MPMD arrays.

    This function redistributes data from a Single Program Multiple Data (SPMD)
    layout to a Multiple Program Multiple Data (MPMD) layout. It handles
    memory constraints by grouping arrays and processing them in chunks.

    It's the caller's responsibility to not use the input spmd_arrays after calling
    this function as they will be consumed by this function.

    The specs of the returned arrays will _not_ have `mpmd_mesh.mpmd_axis_name` in
    them.

    Limitations: same constraints as jax.jit apply (e.g. _device_assignment must be the
    same for all arrays)

    Args:
        mpmd_mesh: The MPMD mesh definition.
        spmd_arrays: A pytree of source SPMD arrays.
        mpmd_shardings: A pytree of target MPMD shardings, matching the structure of
            spmd_arrays.
        threshold: Memory threshold in bytes for grouping operations.
            If None, calculated based on available memory.

    Returns:
        A pytree of MpmdArray objects with the same structure as spmd_arrays.
    """
    spmd_arrays_with_path, spmd_tree_def = jax.tree.flatten_with_path(spmd_arrays)
    mpmd_shardings_flat, mpmd_tree_def = jax.tree.flatten(mpmd_shardings)
    # For unused arrays (len(dsh.mesh_ids) == 0), we default their placement
    # to mpmd rank 0
    mpmd_shardings_flat = [
        MpmdSharding(mpmd_mesh=dsh.mpmd_mesh, mesh_ids={0}, spec=dsh.spec)
        if len(dsh.mesh_ids) == 0
        else dsh
        for dsh in mpmd_shardings_flat
    ]
    assert spmd_tree_def == mpmd_tree_def
    # Avoid the private jax._src.util.unzip2 API: only the leaves are needed.
    spmd_arrays_flat_list = [arr for _, arr in spmd_arrays_with_path]
    # Build interleaved order by mpmd_idx (largest first within each idx)
    order = _build_mpmd_interleaved_order(spmd_arrays_flat_list, mpmd_shardings_flat)
    ordered_arrays = [spmd_arrays_flat_list[i] for i in order]
    # Group by memory threshold; each array's footprint is its local shard
    # size times the mpmd dimension (it is replicated across ranks in-flight).
    threshold = threshold if threshold is not None else _get_working_memory_threshold()
    groups = _group_by_size_threshold(
        [
            a.addressable_shards[0].data.nbytes * mpmd_mesh.mpmd_dim
            for a in ordered_arrays
        ],
        threshold,
    )
    resharded_arrays_by_index: dict[int, MpmdArray] = {}
    for group_indices in groups:
        # Map group indices back to original indices
        orig_indices = [order[i] for i in group_indices]
        group_arrays = [spmd_arrays_flat_list[i] for i in orig_indices]
        group_mpmd_shardings = [mpmd_shardings_flat[i] for i in orig_indices]
        group_results = _spmd_to_mpmd_reshard(
            mpmd_mesh,
            group_arrays,
            group_mpmd_shardings,
            donate=[True] * len(group_arrays),  # FIXME: Maybe make it a kwarg
        )
        for orig_idx, result in zip(orig_indices, group_results, strict=True):
            resharded_arrays_by_index[orig_idx] = result
    resharded_flat = [
        resharded_arrays_by_index[i] for i in range(len(spmd_arrays_flat_list))
    ]
    return jax.tree.unflatten(spmd_tree_def, resharded_flat)
def _group_by_size_threshold(
sizes: list[int],
threshold: int,
) -> list[list[int]]:
"""Group indices by size threshold.
Groups are formed sequentially - entries are added to the current group
until adding another would exceed the threshold, then a new group starts.
Returns groups of indices into the input sizes list.
"""
groups: list[list[int]] = []
current_group: list[int] = []
current_size = 0
for i, size in enumerate(sizes):
if current_size + size > threshold and current_group:
groups.append(current_group)
current_group = []
current_size = 0
current_group.append(i)
current_size += size
if current_group:
groups.append(current_group)
return groups
def _axis_name_in_spec(axis_name: str, spec) -> bool:
for elem in spec:
if elem == axis_name:
return True
if isinstance(elem, tuple) and axis_name in elem:
return True
return False
def logically_stacked(
    array: jax.Array,
    comm_mesh: jax.sharding.Mesh,
    mesh_axis_name: str,
    array_axis: int = 0,
    strict: bool = False,
):
    """
    Logically stacks an array along a new axis corresponding to the MPMD dimension.

    The array gains a size-1 dimension at ``array_axis`` and is re-declared as
    a global array whose new dimension is sharded over ``mesh_axis_name`` of
    ``comm_mesh``, so the per-rank shards act as slices of one logical stack.
    With ``strict`` the mesh axis must be absent from the array's current spec;
    otherwise it is filtered out.
    """
    sharding = array.sharding
    assert isinstance(sharding, jax.sharding.NamedSharding)
    if strict:
        spec = sharding.spec
        assert not _axis_name_in_spec(
            mesh_axis_name, spec
        ), f"axis_name {mesh_axis_name!r} already exists in spec {spec}"
    else:
        spec = filter_axes(sharding.spec, {mesh_axis_name})
    stacked_spec = jax.sharding.PartitionSpec(
        *spec[:array_axis], mesh_axis_name, *spec[array_axis:]
    )
    stacked_shape = (
        *array.shape[:array_axis],
        comm_mesh.shape[mesh_axis_name],
        *array.shape[array_axis:],
    )
    expanded = jax.numpy.expand_dims(array, array_axis)
    return jax.make_array_from_single_device_arrays(
        stacked_shape,
        jax.sharding.NamedSharding(comm_mesh, stacked_spec),
        [shard.data for shard in expanded.addressable_shards],
    )
def _select_mpmd_slice(arrays, mpmd_idxs):
"""Selector function to pick the slice corresponding to the MPMD index."""
return tuple(array[idx] for array, idx in zip(arrays, mpmd_idxs, strict=True))
def mpmd_to_spmd_reshard(
    mpmd_mesh: MpmdMesh, mpmd_arrays, spmd_shardings, threshold: int | None = None
) -> jax.Array:
    """
    Reshards a pytree of MPMD arrays to SPMD arrays.

    This function redistributes data from a Multiple Program Multiple Data (MPMD)
    layout back to a Single Program Multiple Data (SPMD) layout. It reconstructs
    global arrays from distributed MPMD shards.

    It's the caller's responsibility to not use the input mpmd_arrays after calling
    this function as they will be consumed by this function.

    Args:
        mpmd_mesh: The MPMD mesh definition.
        mpmd_arrays: A pytree of source MPMD arrays.
        spmd_shardings: A pytree of target SPMD shardings.
        threshold: Memory threshold in bytes for grouping operations.
            If None, calculated based on available memory.

    Returns:
        A pytree of JAX arrays with the same structure as mpmd_arrays.
        NOTE(review): the `-> jax.Array` annotation looks too narrow for the
        pytree return described above — confirm and fix the annotation.
    """
    if not mpmd_mesh.jax_mesh.is_multi_process:
        # Single-process: every replica is addressable locally, so a plain
        # device_put of the first replica suffices.
        return jax.device_put(
            jax.tree.map(lambda _: _.first_mpmd_replica, mpmd_arrays), spmd_shardings
        )
    mpmd_arrays_with_path, mpmd_tree_def = jax.tree.flatten_with_path(mpmd_arrays)
    mpmd_arrays_with_path: list[tuple[Any, MpmdArray]]
    spmd_shardings_flat, spmd_tree_def = jax.tree.flatten(spmd_shardings)
    assert mpmd_tree_def == spmd_tree_def
    donate = [True] * len(mpmd_arrays_with_path)  # FIXME: Maybe make it a kwarg
    # Collect metadata without building stacked arrays yet
    mpmd_arr_list = [mpmd_arr for _, mpmd_arr in mpmd_arrays_with_path]
    # The authoritative replica lives at each array's smallest mpmd index.
    mpmd_idxs = [mpmd_arr._mpmd_idxs[0] for mpmd_arr in mpmd_arr_list]

    def get_shard_size(mpmd_arr: MpmdArray) -> int:
        # Bytes of a single local shard of the array.
        return int(
            np.prod(mpmd_arr._mpmd_local_sharding.shard_shape(mpmd_arr.shape))
            * np.dtype(mpmd_arr.dtype).itemsize
        )

    sizes = [get_shard_size(mpmd_arr) for mpmd_arr in mpmd_arr_list]
    # Sort by size (largest first) for better memory efficiency
    order = sorted(range(len(mpmd_arr_list)), key=lambda i: sizes[i], reverse=True)
    # Footprint per array is shard size times the mpmd dimension, since the
    # stacked form spans all mpmd ranks.
    groups = _group_by_size_threshold(
        [sizes[i] * mpmd_mesh.mpmd_dim for i in order],
        threshold if threshold is not None else _get_working_memory_threshold(),
    )
    resharded_arrays_by_index: dict[int, jax.Array] = {}
    for group_indices in groups:
        orig_indices = [order[i] for i in group_indices]
        # Build stacked arrays for this group
        group_stacked = []
        for orig_idx in orig_indices:
            mpmd_arr = mpmd_arr_list[orig_idx]
            local_array = mpmd_arr.first_mpmd_replica
            if local_array is None:
                # Create zeros if this rank doesn't hold data for this array;
                # the real slice is selected by index below, so the zeros are
                # never read by this rank's output.
                local_array = jax.jit(
                    jax.numpy.zeros,
                    static_argnums=(0, 1),
                    out_shardings=mpmd_arr._mpmd_local_sharding,
                )(mpmd_arr.shape, mpmd_arr.dtype)
            stacked = logically_stacked(
                local_array, mpmd_mesh.jax_mesh, mpmd_mesh.mpmd_axis_name
            )
            # logically_stacked creates a new array, so we can delete the local array
            if donate[orig_idx]:
                local_array.delete()
            group_stacked.append(stacked)
        group_stacked = tuple(group_stacked)
        group_mpmd_idxs = tuple(mpmd_idxs[i] for i in orig_indices)
        group_spmd_shardings = tuple(spmd_shardings_flat[i] for i in orig_indices)
        # _select_mpmd_slice takes the stacked tuple as one positional arg;
        # mpmd_idxs is static (argnum 1).
        in_shardings = (tuple(_.sharding for _ in group_stacked),)
        group_results = jax.jit(
            _select_mpmd_slice,
            in_shardings=in_shardings,
            out_shardings=group_spmd_shardings,
            static_argnums=(1,),
        )(group_stacked, group_mpmd_idxs)
        for i, orig_idx in enumerate(orig_indices):
            resharded_arrays_by_index[orig_idx] = group_results[i]
            # Free the temporary stacked buffer as soon as its result exists.
            group_stacked[i].delete()
    resharded_flat = [resharded_arrays_by_index[i] for i in range(len(mpmd_arr_list))]
    return jax.tree.unflatten(spmd_tree_def, resharded_flat)