warp.tile_sum

warp.tile_sum(
    a: Tile[Any, tuple[int, ...]],
    axis: int32,
) → Tile[Any, tuple[int, ...]]
  • Kernel

  • Differentiable

Cooperatively compute the sum of the tile elements.

Reduce across a tile axis using all threads in the block.
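The pure-Python sketch below illustrates one common strategy for a block-wide cooperative sum. Each simulated thread first accumulates a strided slice of the values, and the per-thread partial sums are then combined pairwise in a tree reduction. This is an illustrative assumption only and does not describe Warp's actual implementation. The helper name `block_reduce_sum` is hypothetical, and the sketch assumes `block_dim` is a power of two.

```python
def block_reduce_sum(values, block_dim=64):
    """Simulate a block-wide sum over `values` using `block_dim` threads.

    Hypothetical helper for illustration; not part of Warp.
    Assumes block_dim is a power of two.
    """
    # Phase 1: thread `tid` sums elements tid, tid + block_dim, tid + 2 * block_dim, ...
    partials = [sum(values[tid::block_dim]) for tid in range(block_dim)]

    # Phase 2: tree reduction. At each step, the lower half of the threads
    # adds in the partial sums held by the upper half.
    stride = block_dim // 2
    while stride > 0:
        for tid in range(stride):
            partials[tid] += partials[tid + stride]
        stride //= 2

    # Thread 0 ends up holding the total.
    return partials[0]


# A 16 x 16 tile of ones, flattened row-major, reduced by 64 simulated threads.
tile = [1.0] * (16 * 16)
print(block_reduce_sum(tile, block_dim=64))  # 256.0
```

With 256 elements and 64 threads, each thread starts with 4 elements, and the tree phase needs log2(64) = 6 pairwise steps.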

Parameters:
  • a – The input tile. Must reside in shared memory.

  • axis – The tile axis to compute the sum across. Must be a compile-time constant.

Returns:

A tile with the same data type as the input tile, whose shape is the input tile's shape with the axis dimension removed.
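As a host-side point of comparison, NumPy's `ndarray.sum(axis=...)` follows the same shape rule: the reduced axis is dropped and the remaining dimensions are kept. The sketch below assumes NumPy is available and uses `float32` to stand in for Warp's `float`. It mirrors the 8 x 8 example and is not a Warp API.

```python
import numpy as np

# Host-side stand-in for an 8 x 8 tile of ones (float32 mirrors Warp's float).
t = np.ones((8, 8), dtype=np.float32)

# Reducing along axis 0 removes that dimension: (8, 8) -> (8,)
s = t.sum(axis=0)
print(s.shape, s.dtype)  # (8,) float32
print(s)                 # [8. 8. 8. 8. 8. 8. 8. 8.]

# The same rule in 3-D: reducing axis 1 of a (2, 3, 4) array leaves (2, 4).
t3 = np.ones((2, 3, 4), dtype=np.float32)
print(t3.sum(axis=1).shape)  # (2, 4)
```

NumPy promotes small integer types (for example `int8`) to a wider integer when summing, so the data-type part of this analogy only holds for the floating-point case shown here.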

Example

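import warp as wp

@wp.kernel
def compute():
    # Build an 8 x 8 tile of ones, then sum along axis 0 (down each column)
    t = wp.tile_ones(dtype=float, shape=(8, 8))
    s = wp.tile_sum(t, axis=0)

    print(s)

wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
wp.synchronize()  # wait for the kernel so its print output is flushed
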
[8 8 8 8 8 8 8 8] = tile(shape=(8), storage=register)

warp.tile_sum(
    a: Tile[Any, tuple[int, ...]],
) → Tile[Any, tuple[Literal[1]]]
  • Kernel

  • Differentiable

Cooperatively compute the sum of the tile elements.

Reduce across all elements using all threads in the block.

Parameters:
  • a – The tile whose elements are summed.

Returns:

A single-element tile holding the sum.
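For this overload, a rough NumPy analogue is a full `sum()` wrapped in a one-element array, which matches the documented single-element result shape. The sketch below assumes NumPy is available and mirrors the 16 x 16 example; it is not a Warp API. It also checks that summing every element gives the same value as reducing one axis at a time.

```python
import numpy as np

# Host-side stand-in for a 16 x 16 tile of ones.
t = np.ones((16, 16), dtype=np.float32)

# Full reduction, kept as a one-element array to mirror the shape-(1) tile.
s = np.array([t.sum()], dtype=t.dtype)
print(s.shape, s)  # (1,) [256.]

# Equivalent two-step reduction: drop axis 0, then sum the remaining 16 values.
assert t.sum(axis=0).sum() == s[0]
```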

Example

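import warp as wp

@wp.kernel
def compute():
    # Build a 16 x 16 tile of ones, then sum all 256 elements
    t = wp.tile_ones(dtype=float, shape=(16, 16))
    s = wp.tile_sum(t)

    print(s)

wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
wp.synchronize()  # wait for the kernel so its print output is flushed
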
[256] = tile(shape=(1), storage=register)