api#

class jaxpp.api.BaseSchedule(num_stages: int)[source]#
class jaxpp.api.Concat(axis: int = 0)[source]#

Represents a concatenation operation along a specified axis.

class jaxpp.api.DualPipeV(num_stages: int, mpmd_dim: int)[source]#
class jaxpp.api.Eager1F1B(num_stages: int)[source]#
class jaxpp.api.Interleaved1F1B(
num_stages: int,
mpmd_dim: int,
fuse_steady_state: bool = False,
)[source]#
class jaxpp.api.MpmdArray(
partially_addressable_arrays: list[Array],
mpmd_sharding: MpmdSharding,
shape: tuple[int, ...] | None = None,
dtype: dtype | None = None,
)[source]#

An array distributed across one or more MPMD groups.

MpmdArray represents a logical array that exists in a subset of the MPMD groups within an MpmdMesh. Unlike standard JAX SPMD arrays, where every process holds a shard of the array, an MpmdArray is only “partially addressable”: it is sharded within one or more MPMD groups (potentially all of them), and only processes in those groups can access its data.

Properties:
  • mpmd_idxs: The set of MPMD group indices where this array exists. Most computed arrays exist in a single group (len=1), but constants and loop invariants may be replicated across multiple groups when needed as inputs by multiple pipeline stages.

  • is_partially_addressable: A process can only access array shards if it belongs to one of the MPMD groups in mpmd_idxs. Use to_mpmd_local_array to get the SPMD JAX array (potentially spanning multiple devices) for this MPMD group.

  • is_mpmd_replicated: When len(mpmd_idxs) > 1, the array data is replicated across those groups.

The sharding property returns a NamedSharding whose mesh spans all devices in mpmd_idxs, useful for resharding between SPMD and MPMD layouts.

Example:

An array with mpmd_idxs={0, 2} on a 4-group MPMD mesh exists in groups 0 and 2. Processes in groups 1 and 3 cannot access this array’s data (is_partially_addressable=False for them).
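The addressability rule above can be sketched in plain Python. This is an illustrative model of the logic, not jaxpp’s implementation:

```python
# Illustrative sketch: partial addressability on a 4-group MPMD mesh.
mpmd_idxs = {0, 2}   # groups where the array exists
num_groups = 4

# A process in group g can access shards only if g is in mpmd_idxs.
addressable = {g: g in mpmd_idxs for g in range(num_groups)}

# Replication across groups is simply len(mpmd_idxs) > 1.
is_mpmd_replicated = len(mpmd_idxs) > 1
```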

Attributes:
  • spec: The PartitionSpec describing how the array is sharded within each MPMD group.

  • aval: The abstract value (shape and dtype) of the array.

property is_mpmd_replicated: bool#

Returns True if the array is replicated in more than one mpmd rank.

property is_partially_addressable: bool#

Returns True if the array is partially addressable in the mpmd rank this process participates in. An array is partially addressable at this rank if this rank holds a shard of the array (the shard can potentially be replicated across multiple mpmd ranks).

property sharding: NamedSharding#

NOTE: this is different from self.to_mpmd_local_array.sharding if self.is_mpmd_replicated.
property to_mpmd_local_array: Array | list[Array] | None#

Returns a jax.Array if the array is partially addressable in the mpmd rank this process participates in; otherwise, returns None. Returns a list of arrays when the mesh is a single-process, multi-device mesh.

class jaxpp.api.MpmdMesh(jax_mesh: Mesh, mpmd_axis_name: str)[source]#

A JAX mesh partitioned into MPMD (Multiple Program Multiple Data) groups.

MpmdMesh wraps a standard JAX mesh and designates one axis as the “MPMD axis”. The mesh is conceptually split into multiple independent groups along this axis, where each group can execute different computations (e.g., pipeline stages).

For example, with a mesh of shape {‘mpmd’: 4, ‘data’: 2, ‘model’: 2} and mpmd_axis_name=’mpmd’, the mesh is split into 4 MPMD groups, each containing 4 devices (2 data x 2 model). Each group runs its own computation, and arrays can be distributed across one or more groups.

Key concepts:
  • MPMD group: A slice of the mesh along the MPMD axis. Each group has an index from 0 to mpmd_dim - 1.

  • Submesh: A subset of MPMD groups combined into a single mesh. Used when arrays are replicated across multiple groups.

  • Lowering mesh: The mesh used for XLA compilation, which is the local process’s MPMD group mesh in multi-process settings.

In multi-process execution, each process belongs to exactly one MPMD group. Arrays may be replicated across multiple groups when needed as inputs by multiple pipeline stages (common for constants and loop invariants).
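The group layout from the example above can be modeled with a NumPy sketch. Device ids stand in for real devices, and the helper name is hypothetical; this mirrors the grouping only, not jaxpp internals:

```python
import numpy as np

# A mesh of shape ('mpmd': 4, 'data': 2, 'model': 2): 16 devices total,
# split into 4 MPMD groups of 4 devices each along the 'mpmd' axis.
devices = np.arange(16).reshape(4, 2, 2)   # (mpmd, data, model)

# Each MPMD group is one slice along the mpmd axis: a 2x2 submesh.
groups = [devices[i] for i in range(devices.shape[0])]
assert all(g.size == 4 for g in groups)

# The group index of a device is its position along the mpmd axis.
def mpmd_group_of(device_id: int) -> int:
    return int(device_id) // (2 * 2)
```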

Attributes:
  • jax_mesh: The underlying JAX mesh containing all devices.

  • mpmd_axis_name: Name of the axis used to partition into MPMD groups.

class jaxpp.api.Std1F1B(num_stages: int)[source]#
class jaxpp.api.ZeroBubble(num_stages: int)[source]#
jaxpp.api.collect_task_times_ms(
enabled: bool = True,
) → Generator[dict[str, list[float]] | None, None, None][source]#

Context manager to collect task execution times in milliseconds.

Example usage:

with collect_task_times_ms() as stats:
    # ... run tasks ...

for task_name, times in stats.items():
    print(f"{task_name}: {times}")

Example usage with collection disabled:

with collect_task_times_ms(enabled=False) as stats:
    # ... run tasks ...

assert stats is None
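The collection pattern behind this context manager can be sketched with contextlib. The record() hook and the module-level state are hypothetical stand-ins; jaxpp’s actual implementation may differ:

```python
import contextlib
from collections import defaultdict

# Sketch of a stats-collecting context manager (not jaxpp internals).
_current_stats = None

@contextlib.contextmanager
def collect_task_times_ms(enabled: bool = True):
    global _current_stats
    if not enabled:
        yield None
        return
    _current_stats = defaultdict(list)
    try:
        yield _current_stats
    finally:
        _current_stats = None

def record(task_name: str, elapsed_ms: float) -> None:
    # Hypothetical hook the runtime would call after each task completes.
    if _current_stats is not None:
        _current_stats[task_name].append(elapsed_ms)
```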
jaxpp.api.mpmd_to_spmd_reshard(
mpmd_mesh: MpmdMesh,
mpmd_arrays,
spmd_shardings,
threshold: int | None = None,
) → Array[source]#

Reshards a pytree of MPMD arrays to SPMD arrays.

This function redistributes data from a Multiple Program Multiple Data (MPMD) layout back to a Single Program Multiple Data (SPMD) layout. It reconstructs global arrays from distributed MPMD shards. It is the caller’s responsibility not to use the input mpmd_arrays after calling this function, as they are consumed by it.

Args:

mpmd_mesh: The MPMD mesh definition.
mpmd_arrays: A pytree of source MPMD arrays.
spmd_shardings: A pytree of target SPMD shardings.
threshold: Memory threshold in bytes for grouping operations. If None, calculated based on available memory.

Returns:

A pytree of JAX arrays with the same structure as mpmd_arrays.

jaxpp.api.spmd_to_mpmd_reshard(
mpmd_mesh: MpmdMesh,
spmd_arrays,
mpmd_shardings,
threshold: int | None = None,
)[source]#

Reshards a pytree of SPMD arrays to MPMD arrays.

This function redistributes data from a Single Program Multiple Data (SPMD) layout to a Multiple Program Multiple Data (MPMD) layout. It handles memory constraints by grouping arrays and processing them in chunks. It is the caller’s responsibility not to use the input spmd_arrays after calling this function, as they are consumed by it.

The specs of the returned arrays will _not_ have mpmd_mesh.mpmd_axis_name in them.

Limitations: the same constraints as jax.jit apply (e.g., _device_assignment must be the same for all arrays).

Args:

mpmd_mesh: The MPMD mesh definition.
spmd_arrays: A pytree of source SPMD arrays.
mpmd_shardings: A pytree of target MPMD shardings, matching the structure of spmd_arrays.
threshold: Memory threshold in bytes for grouping operations. If None, calculated based on available memory.

Returns:

A pytree of MpmdArray objects with the same structure as spmd_arrays.
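The threshold parameter of both reshard functions bounds how much data moves at once. A minimal sketch of such threshold-based grouping, with a hypothetical helper name (not jaxpp’s implementation):

```python
# Batch arrays into chunks whose total byte size stays under the
# threshold, so each reshard transfer fits in memory.
def chunk_by_bytes(sizes_bytes, threshold):
    chunks, current, current_bytes = [], [], 0
    for i, n in enumerate(sizes_bytes):
        # Start a new chunk when adding this array would exceed the limit.
        if current and current_bytes + n > threshold:
            chunks.append(current)
            current, current_bytes = [], 0
        current.append(i)
        current_bytes += n
    if current:
        chunks.append(current)
    return chunks
```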

jaxpp.api.treduce(
fun: Callable[[X], Y],
xs: X,
schedule: BaseSchedule,
axis: int = 0,
operation=(Concat(axis=0), AddT()),
) → Y[source]#

Temporally reduces a sequence of inputs with a pipelined schedule.

This function behaves like the functional-programming primitive reduce applied along the leading (time / micro-batch) axis of xs. At each timestep i it applies fun to the slice xs[i] and combines the resulting values using operation, as shown in the following example:

def treduce(fun, xs, axis=0, operation=(Concat(), Add())):
  # xs has its time / micro-batch dimension along `axis`
  length = xs.shape[axis]
  state = tree_map(lambda a, op: op.state(length, a),
                   fun(jnp.take(xs, 0, axis=axis)), operation)
  for i in range(1, length):
    state = tree_map(lambda op, s, v: op.update(s, v, i),
                     operation, state, fun(jnp.take(xs, i, axis=axis)))
  return state

Unlike a vanilla reduce, the execution of the loop body may be interleaved according to schedule so that different timesteps can overlap on the accelerator.
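The sequential semantics (without any pipelined overlap) can be reproduced with a small NumPy sketch. The function name and the two-output shape of fun are illustrative assumptions mirroring the default (Concat, Add) operation pair:

```python
import numpy as np

def treduce_ref(fun, xs, axis=0):
    """Sequential reference: concatenate the first output of fun over
    timesteps, sum the second (mirroring Concat/Add-style ops)."""
    length = xs.shape[axis]
    concat_parts, total = None, None
    for i in range(length):
        y_cat, y_sum = fun(np.take(xs, i, axis=axis))
        if i == 0:
            concat_parts, total = [y_cat], y_sum
        else:
            concat_parts.append(y_cat)
            total = total + y_sum
    return np.stack(concat_parts, axis=0), total
```

The pipelined treduce produces the same values; only the execution order across timesteps differs.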

Args:

fun: A function that is applied to a single slice. It receives one element xs[i] (a PyTree slice with the leading axis removed) and returns a PyTree.
xs: A PyTree whose leaf nodes are arrays. All leaf arrays must share the same leading dimension size; this leading dimension is the axis that is reduced over.
schedule: A BaseSchedule specifying how loop iterations should overlap.
axis: Integer specifying which axis of every leaf array in xs represents the micro-batch dimension. At iteration i the function gathers the slice at index i along this axis and feeds it to fun.
operation: A PyTree of Op objects (default: default_op, i.e., (Concat(axis=0), AddT())) describing how the per-timestep values are aggregated. Each leaf Op object defines state and update methods. Convenient predefined ops include:

  • Add: Element-wise sum.

  • Max: Element-wise maximum.

  • Concat(axis=0): Stacks results from each timestep. The axis parameter specifies the dimension in the output array that corresponds to the reduction timesteps.

Returns:

A PyTree containing the aggregated result, with the same structure as the value returned by fun.

jaxpp.api.treduce_i(
fun: Callable[[int], Y],
length: int,
schedule: BaseSchedule,
operation=(Concat(axis=0), AddT()),
) → Y[source]#

Lower-level helper for treduce() that takes an explicit length.

Instead of slicing from a pre-materialised batch this variant invokes fun(i) directly for each 0 <= i < length and reduces the returned values using operation:

def treduce_i(fun, length, operation):
  state = tree_map(lambda a, op: op.state(length, a),
                   fun(0), operation)
  for i in range(1, length):
    state = tree_map(lambda op, s, v: op.update(s, v, i),
                     operation, state, fun(i))
  return state
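The state/update protocol sketched above can be made runnable with minimal stand-in ops. These classes and the reference function are illustrative assumptions, not jaxpp’s actual Op interface:

```python
# Minimal stand-in Op protocol mirroring the state/update methods
# described above (sketch only; jaxpp's Op classes may differ).
class AddOp:
    def state(self, length, first):
        return first
    def update(self, state, value, i):
        return state + value

class StackOp:
    def state(self, length, first):
        return [first]
    def update(self, state, value, i):
        state.append(value)
        return state

def treduce_i_ref(fun, length, ops):
    # Initialise per-leaf state from the first timestep, then fold in
    # the remaining timesteps one by one (no pipelined overlap here).
    vals = fun(0)
    state = tuple(op.state(length, v) for op, v in zip(ops, vals))
    for i in range(1, length):
        vals = fun(i)
        state = tuple(op.update(s, v, i)
                      for op, s, v in zip(ops, state, vals))
    return state
```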

Args:

fun: Function that receives the micro-batch/timestep index i (an integer) and returns a PyTree to be reduced.
length: The number of timesteps / micro-batches.
schedule: A BaseSchedule determining how iterations may overlap.
operation: A PyTree of Op objects (default: default_op, i.e., (Concat(axis=0), AddT())) controlling the accumulation. See treduce() for details. Convenient predefined ops include:

  • Add: Element-wise sum.

  • Max: Element-wise maximum.

  • Concat(axis=0): Stacks results from each timestep. The axis parameter specifies the dimension in the output array that corresponds to the reduction timesteps.

Returns:

A PyTree containing the result of the temporal reduction, with the same structure as the value returned by fun.