warp.tile_reduce

warp.tile_reduce(
    op: Callable,
    a: Tile[Any, tuple[int, ...]],
) → Tile[Any, tuple[1]]
  • Kernel

Apply a custom reduction operator across the tile.

This function cooperatively reduces all elements of the tile to a single value using the provided operator.
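This section does not say in what order the elements are combined. As an illustration only, the plain-Python sketch below models one common way a cooperative reduction is organized: a pairwise tree, where each round combines adjacent pairs in parallel. `tree_reduce` is a hypothetical helper, not part of Warp, and is not a description of Warp's internal implementation. Because a tree regroups the operands, a reduction of this kind gives order-independent results only for associative operators; with floating-point addition the last bits may differ from a left-to-right sum.

```python
import operator


def tree_reduce(op, values):
    """Hypothetical model of a pairwise (tree) reduction, not Warp's code.

    Each round combines adjacent pairs, as parallel threads would, until a
    single value remains.
    """
    level = list(values)
    if not level:
        raise ValueError("cannot reduce an empty tile")
    while len(level) > 1:
        nxt = []
        # Each "thread" combines one adjacent pair.
        for i in range(0, len(level) - 1, 2):
            nxt.append(op(level[i], level[i + 1]))
        # With an odd count, the last element is carried to the next round.
        if len(level) % 2 == 1:
            nxt.append(level[-1])
        level = nxt
    return level[0]


print(tree_reduce(operator.mul, range(1, 10)))  # 362880
print(tree_reduce(max, [3, 9, 2, 7]))           # 9
```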

param op:

A callable that accepts two arguments and returns a single value. It may be a user function or a builtin.

param a:

The input tile. The operator (or one of its overloads) must accept the tile’s data type.

returns:

A single-element tile with the same data type as the input tile.
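Since this overload reduces across the whole tile, the result is a single value regardless of the tile's rank. The sketch below is a hypothetical pure-Python reference model (`flatten` and `tile_reduce_ref` are illustrative names, not Warp APIs) that represents a tile as nested lists, folds every element with `op`, and wraps the result in a one-element list to mirror the single-element tile that is returned.

```python
from functools import reduce


def flatten(tile):
    """Yield the scalar elements of a nested-list "tile" in row-major order."""
    for item in tile:
        if isinstance(item, list):
            yield from flatten(item)
        else:
            yield item


def tile_reduce_ref(op, tile):
    """Hypothetical reference model: fold all elements into a 1-element list."""
    return [reduce(op, flatten(tile))]


t = [[1, 2, 3], [4, 5, 6]]  # a 2x3 "tile"
print(tile_reduce_ref(lambda a, b: a + b, t))  # [21]
print(tile_reduce_ref(min, t))                 # [1]
```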

Example:

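import warp as wp

@wp.kernel
def factorial():

    # tile containing the integers 1..9 (the stop value 10 is excluded)
    t = wp.tile_arange(1, 10, dtype=int)
    # multiply all elements together: 1 * 2 * ... * 9
    s = wp.tile_reduce(wp.mul, t)

    print(s)

wp.launch_tiled(factorial, dim=[1], inputs=[], block_dim=16)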

Prints:

[362880] = tile(shape=(1), storage=register)
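The printed value can be checked on the host with the Python standard library: the tile holds the integers 1 through 9, so the product is 9!.

```python
import math

# The tile in the example holds 1, 2, ..., 9.
values = range(1, 10)
print(math.prod(values))   # 362880
print(math.factorial(9))   # 362880
```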

warp.tile_reduce(
    op: Callable,
    a: Tile[Scalar, tuple[int, ...]],
    axis: int32,
) → Tile[Scalar, tuple[int, ...]]
  • Kernel

Apply a custom reduction operator across a tile axis.

This function cooperatively reduces the tile along the given axis using the provided operator.
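To make the role of `axis` concrete, the sketch below is a hypothetical pure-Python reference model for a 2D tile stored as nested lists; `reduce_axis_2d` is an illustrative name, not a Warp API. Reducing across axis 1 combines the elements within each row, while reducing across axis 0 combines the elements within each column. The input values match the example later in this section.

```python
from functools import reduce


def reduce_axis_2d(op, tile, axis):
    """Hypothetical reference model of an axis reduction on a 2D tile.

    axis=1 combines the elements of each row (one result per row);
    axis=0 combines the elements of each column (one result per column).
    """
    if axis == 1:
        return [reduce(op, row) for row in tile]
    if axis == 0:
        return [reduce(op, col) for col in zip(*tile)]
    raise ValueError("axis must be 0 or 1 for a 2D tile")


x = [[0, 1], [2, 3], [4, 5], [6, 7]]
print(reduce_axis_2d(lambda a, b: a + b, x, axis=1))  # [1, 5, 9, 13]
print(reduce_axis_2d(lambda a, b: a + b, x, axis=0))  # [12, 16]
```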

param op:

A callable that accepts two arguments and returns a single value. It may be a user function or a builtin.

param a:

The input tile. The operator (or one of its overloads) must accept the tile’s data type. Must reside in shared memory.

param axis:

The tile axis to perform the reduction across. Must be a compile-time constant.

returns:

A tile with the same data type as the input tile, whose shape is the input tile’s shape with the reduced axis removed.
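The output shape rule can be written as a small helper. `reduced_shape` is hypothetical and only computes shapes; it assumes a non-negative `axis` smaller than the tile's rank, since this section does not say whether negative axes are accepted. For the example below, a (4, 2) tile reduced across axis 1 gives shape (4,), matching the four values written to `y`.

```python
def reduced_shape(shape, axis):
    """Hypothetical helper: the input shape with dimension `axis` removed.

    Assumes 0 <= axis < len(shape); negative axes are not handled.
    """
    if not 0 <= axis < len(shape):
        raise ValueError(f"axis {axis} out of range for rank {len(shape)}")
    return shape[:axis] + shape[axis + 1:]


print(reduced_shape((4, 2), 1))     # (4,)
print(reduced_shape((4, 2, 8), 0))  # (2, 8)
```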

Example:

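import numpy as np
import warp as wp

TILE_M = wp.constant(4)
TILE_N = wp.constant(2)

@wp.kernel
def compute(x: wp.array2d(dtype=float), y: wp.array(dtype=float)):

    # load the TILE_M x TILE_N input into a tile
    a = wp.tile_load(x, shape=(TILE_M, TILE_N))
    # sum across axis 1 (within each row), giving a tile of shape (TILE_M,)
    b = wp.tile_reduce(wp.add, a, axis=1)
    wp.tile_store(y, b)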

arr = np.arange(TILE_M * TILE_N).reshape(TILE_M, TILE_N)

x = wp.array(arr, dtype=float)
y = wp.zeros(TILE_M, dtype=float)

wp.launch_tiled(compute, dim=[1], inputs=[x], outputs=[y], block_dim=32)

print(x.numpy())
print(y.numpy())

Prints:

[[0. 1.]
 [2. 3.]
 [4. 5.]
 [6. 7.]]
[ 1.  5.  9. 13.]