warp.tile_sum

warp.tile_sum(
a: Tile[Any, tuple[int, ...]],
axis: int32,
) → Tile[Any, tuple[int, ...]]
  • Kernel

  • Differentiable

Cooperatively compute the sum of the tile elements across an axis of the tile using all threads in the block.

param a:

The input tile. Must reside in shared memory.

param axis:

The tile axis to compute the sum across. Must be a compile-time constant.

returns:

A tile with the same data type as the input tile and the same shape as the input tile with the axis dimension removed.

Example:

import warp as wp

@wp.kernel
def compute():
    # 8x8 tile of ones
    t = wp.tile_ones(dtype=float, shape=(8, 8))
    # Sum across axis 0, leaving one value per column
    s = wp.tile_sum(t, axis=0)

    print(s)

wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)

Prints:

[8 8 8 8 8 8 8 8] = tile(shape=(8), storage=register)
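
The shape rule for the axis overload can be checked on the host with a NumPy analogue. This is only a sketch of the reduction semantics, not Warp code: np.ones stands in for the tile, and reduced_shape is a hypothetical helper introduced here for illustration.

```python
import numpy as np

# NumPy stand-in for a (4, 8) tile of ones; this is not Warp code.
t = np.ones((4, 8), dtype=np.float32)

def reduced_shape(shape, axis):
    # Hypothetical helper: the input shape with the reduced axis dropped.
    return tuple(d for i, d in enumerate(shape) if i != axis)

col_sums = t.sum(axis=0)  # collapses the 4 rows: shape (8,), every entry 4
row_sums = t.sum(axis=1)  # collapses the 8 columns: shape (4,), every entry 8

# The result keeps the input dtype and loses exactly the summed axis.
assert col_sums.shape == reduced_shape(t.shape, 0)
assert row_sums.shape == reduced_shape(t.shape, 1)
assert col_sums.dtype == t.dtype
print(col_sums.shape, row_sums.shape)
```

A non-square shape is used so that the two choices of axis give visibly different results.
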
warp.tile_sum(
a: Tile[Any, tuple[int, ...]],
) → Tile[Any, tuple[1]]
  • Kernel

  • Differentiable

Cooperatively compute the sum of the tile elements using all threads in the block.

param a:

The tile to compute the sum of.

returns:

A single-element tile holding the sum.

Example:

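import warp as wp

@wp.kernel
def compute():
    # 16x16 tile of ones (256 elements)
    t = wp.tile_ones(dtype=float, shape=(16, 16))
    # Sum every element into a single-element tile
    s = wp.tile_sum(t)

    print(s)

wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)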

Prints:

[256] = tile(shape=(1), storage=register)
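
As a host-side cross-check, the full reduction can be mirrored in NumPy. This is a sketch under the assumption that the single-element tile behaves like a length-1 array; it does not call Warp, and the variable names are illustrative only.

```python
import numpy as np

# NumPy stand-in for a (16, 16) tile of ones; this is not Warp code.
t = np.ones((16, 16), dtype=np.float32)

# Sum every element and keep the result as a length-1 array,
# mirroring the single-element tile described above.
total = t.sum().reshape(1)

# Reducing one axis first and then summing the remainder gives the same value.
by_axis = t.sum(axis=0).sum().reshape(1)

assert total.shape == (1,)
assert total[0] == 256.0
assert np.array_equal(total, by_axis)
```
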