warp.tile_store_indexed

warp.tile_store_indexed(
a: Array[Any],
indices: Tile[int32, tuple[int]],
t: Tile[Any, tuple[int, ...]],
offset: tuple[int, ...],
axis: int32,
) -> None
  • Kernel

  • Differentiable

Store a tile to an array in global memory, remapping positions along a specified axis through a 1D tile of indices.
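The pure-Python sketch below is a hypothetical reference model of the store for a 2D destination, inferred from the description above and from the worked example on this page. With axis=0, the tile element at (k, c) lands at a[offset[0] + indices[k]][offset[1] + c]; with axis=1, the roles of rows and columns swap. The function name store_indexed_2d is illustrative only, nested lists stand in for Warp arrays and tiles, and skipping out-of-range targets is an assumption, not documented Warp behavior.

```python
def store_indexed_2d(dest, indices, tile, offset, axis):
    """Hypothetical reference model of an indexed 2D tile store (not Warp's API).

    dest:    2D list (destination array), modified in place
    indices: 1D list of ints, one per tile element along `axis`
    tile:    2D list (source tile)
    offset:  (row, col) offset into dest
    axis:    0 to remap rows, 1 to remap columns
    """
    rows, cols = len(tile), len(tile[0])
    # The tile extent along `axis` must match the number of indices.
    assert (rows if axis == 0 else cols) == len(indices)
    for r in range(rows):
        for c in range(cols):
            if axis == 0:
                dr, dc = offset[0] + indices[r], offset[1] + c
            else:
                dr, dc = offset[0] + r, offset[1] + indices[c]
            # Assumption for this sketch: out-of-range targets are skipped.
            if 0 <= dr < len(dest) and 0 <= dc < len(dest[0]):
                dest[dr][dc] = tile[r][c]
    return dest


# Usage: send tile rows 0 and 1 to destination rows 0 and 2.
y = [[0] * 3 for _ in range(4)]
store_indexed_2d(y, [0, 2], [[1, 2, 3], [4, 5, 6]], (0, 0), axis=0)
print(y)
```

Modeling the store on the host like this can help when checking index arithmetic before writing the kernel.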

param a:

The destination array in global memory.

param indices:

A 1D tile of integer indices selecting positions along the chosen axis of the destination array.

param t:

The source tile to store data from. It must have the same data type and number of dimensions as the destination array, and its extent along the chosen axis must equal the number of elements in the indices tile.

param offset:

Offset into the destination array (optional). Along the indexed axis, the offset is added to each index.

param axis:

The axis of the destination array that the indices refer to.
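The constraints listed for t, indices, and axis can be checked on the host before a launch. The following is a minimal sketch; the function check_store_indexed_args is hypothetical, shapes are plain tuples, dtypes are plain strings, and Warp's own validation may differ in what it checks and when.

```python
def check_store_indexed_args(a_shape, a_dtype, t_shape, t_dtype, indices_shape, axis):
    """Hypothetical host-side check of the documented parameter constraints.

    Returns a list of problems; an empty list means the arguments are compatible.
    """
    problems = []
    if len(indices_shape) != 1:
        problems.append("indices must be a 1D tile")
    if t_dtype != a_dtype:
        problems.append("t and a must have the same dtype")
    if len(t_shape) != len(a_shape):
        problems.append("t and a must have the same number of dimensions")
    if not 0 <= axis < len(a_shape):
        problems.append("axis is out of range for a")
    elif len(indices_shape) == 1 and axis < len(t_shape) and t_shape[axis] != indices_shape[0]:
        problems.append("t.shape[axis] must equal the number of indices")
    return problems


# Shapes from the example on this page: y is (8, 4), each tile is (2, 2), two indices.
print(check_store_indexed_args((8, 4), "float32", (2, 2), "float32", (2,), 0))
```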

This example shows how to map tile rows to the even rows of a 2D array:

import numpy as np
import warp as wp

device = wp.get_device()  # use the default Warp device

TILE_M = wp.constant(2)
TILE_N = wp.constant(2)
TWO_M = wp.constant(TILE_M * 2)  # row stride between consecutive tiles in y

@wp.kernel
def compute(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
    i, j = wp.tid()

    # Load the (i, j) block of x into a register tile.
    t = wp.tile_load(x, shape=(TILE_M, TILE_N), offset=(i*TILE_M, j*TILE_N), storage="register")

    # Indices [0, 2, ...]: tile row k goes to row 2*k, relative to the offset.
    evens_M = wp.tile_arange(TILE_M, dtype=int, storage="shared") * 2

    # Scatter the tile rows along axis 0 of y.
    wp.tile_store_indexed(y, indices=evens_M, t=t, offset=(i*TWO_M, j*TILE_N), axis=0)

M = TILE_M * 2
N = TILE_N * 2

arr = np.arange(M * N, dtype=float).reshape(M, N)

x = wp.array(arr, dtype=float, requires_grad=True, device=device)
# y has twice as many rows as x, so every other row is left at zero.
y = wp.zeros((M * 2, N), dtype=float, requires_grad=True, device=device)

wp.launch_tiled(compute, dim=[2, 2], inputs=[x], outputs=[y], block_dim=32, device=device)

print(x.numpy())
print(y.numpy())

Prints:

[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]
 [12. 13. 14. 15.]]

[[ 0.  1.  2.  3.]
 [ 0.  0.  0.  0.]
 [ 4.  5.  6.  7.]
 [ 0.  0.  0.  0.]
 [ 8.  9. 10. 11.]
 [ 0.  0.  0.  0.]
 [12. 13. 14. 15.]
 [ 0.  0.  0.  0.]]
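The printed result can be reproduced without a GPU by replaying the 2 x 2 launch in plain Python. This sketch assumes the store writes tile row k, column c of tile (i, j) to y[i*TWO_M + evens_M[k]][j*TILE_N + c], which is the offset-plus-index arithmetic the example relies on. Written out this way, it shows that row r of x always ends up in row 2*r of y.

```python
TILE_M, TILE_N = 2, 2
TWO_M = TILE_M * 2
M, N = TILE_M * 2, TILE_N * 2

# Same contents as np.arange(M * N).reshape(M, N), as nested lists.
x = [[float(r * N + c) for c in range(N)] for r in range(M)]
y = [[0.0] * N for _ in range(M * 2)]

# Mirror launch_tiled(dim=[2, 2]): one (i, j) pair per tile.
for i in range(2):
    for j in range(2):
        evens = [2 * k for k in range(TILE_M)]  # tile_arange(TILE_M) * 2
        for k in range(TILE_M):        # row within the tile
            for c in range(TILE_N):    # column within the tile
                src = x[i * TILE_M + k][j * TILE_N + c]
                # Axis 0: offset plus remapped index. Axis 1: plain offset.
                y[i * TWO_M + evens[k]][j * TILE_N + c] = src

for row in y:
    print(row)
```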