warp.tile_store_indexed

warp.tile_store_indexed(
    a: Array[Any],
    indices: Tile[int32, tuple[int]],
    t: Tile[Any, tuple[int, ...]],
    offset: tuple[int, ...],
    axis: int32,
) -> None
  • Kernel

  • Differentiable

Store a tile to an array in global memory. Along the specified axis, tile positions are not written contiguously but are mapped to destination positions through a 1D tile of indices.

Parameters:
  • a – The destination array in global memory.

  • indices – A 1D tile of integer indices into a along axis: tile position k along axis is written to index indices[k] (relative to offset).

  • t – The source tile to store. It must have the same data type and number of dimensions as a, and its extent along axis must equal the number of elements in indices.

  • offset – Offset into the destination array (optional).

  • axis – The axis of a that indices refer to.
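To make the index mapping concrete, here is a minimal NumPy sketch of the store's semantics. It is a hypothetical reference model, not Warp's implementation: the name tile_store_indexed_ref is illustrative, and the rule that offset is also added to the indexed coordinate is inferred from the example below. It does not model bounds handling, gradients, or what happens when indices contains duplicates.

```python
import numpy as np


def tile_store_indexed_ref(a, indices, t, offset, axis):
    """Hypothetical reference model of an indexed tile store (sketch only).

    Along `axis`, tile position k is written to a-coordinate
    offset[axis] + indices[k]; along every other dimension d,
    tile position p is written to offset[d] + p.
    """
    indices = np.asarray(indices)
    t = np.asarray(t)
    if a.ndim != t.ndim:
        raise ValueError("a and t must have the same number of dimensions")
    if t.shape[axis] != indices.shape[0]:
        raise ValueError("t.shape[axis] must equal the number of indices")

    # Build one 1D coordinate array per dimension of the destination.
    coords = []
    for d in range(t.ndim):
        if d == axis:
            coords.append(offset[d] + indices)  # indexed (scattered) axis
        else:
            coords.append(offset[d] + np.arange(t.shape[d]))  # contiguous axis

    # np.ix_ forms the open mesh, so t lands on the cross product of coords.
    a[np.ix_(*coords)] = t
```

For example, storing a 2x2 tile with indices [0, 2] along axis 0 writes the tile's rows to rows 0 and 2 of the destination and leaves rows 1 and 3 untouched, which is the same pattern the example below produces per block.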

Example

This example shows how to map tile rows to the even rows of a 2D array.

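import numpy as np
import warp as wp

device = wp.get_device()  # default device; pass e.g. "cuda:0" to pick a GPU explicitly

TILE_M = wp.constant(2)
TILE_N = wp.constant(2)
TWO_M = wp.constant(TILE_M * 2)
TWO_N = wp.constant(TILE_N * 2)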

@wp.kernel
def compute(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
    i, j = wp.tid()

    t = wp.tile_load(x, shape=(TILE_M, TILE_N), offset=(i*TILE_M, j*TILE_N), storage="register")

    evens_M = wp.tile_arange(TILE_M, dtype=int, storage="shared") * 2

    wp.tile_store_indexed(y, indices=evens_M, t=t, offset=(i*TWO_M, j*TILE_N), axis=0)

M = TILE_M * 2
N = TILE_N * 2

arr = np.arange(M * N, dtype=float).reshape(M, N)

x = wp.array(arr, dtype=float, requires_grad=True, device=device)
y = wp.zeros((M * 2, N), dtype=float, requires_grad=True, device=device)

wp.launch_tiled(compute, dim=[2,2], inputs=[x], outputs=[y], block_dim=32, device=device)

print(x.numpy())
print(y.numpy())
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]
 [12. 13. 14. 15.]]

[[ 0.  1.  2.  3.]
 [ 0.  0.  0.  0.]
 [ 4.  5.  6.  7.]
 [ 0.  0.  0.  0.]
 [ 8.  9. 10. 11.]
 [ 0.  0.  0.  0.]
 [12. 13. 14. 15.]
 [ 0.  0.  0.  0.]]
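The expected output can be cross-checked without Warp. The sketch below is a plain NumPy emulation, an assumption-based model rather than how Warp actually executes the kernel: it loops over the 2x2 launch grid and applies the same tile offsets and even-row indices that compute uses, then confirms that every row of x lands on an even row of y.

```python
import numpy as np

TILE_M, TILE_N = 2, 2
M, N = TILE_M * 2, TILE_N * 2

x = np.arange(M * N, dtype=float).reshape(M, N)
y = np.zeros((M * 2, N))

evens = np.arange(TILE_M) * 2  # mirrors evens_M in the kernel: [0, 2]

# Emulate launch_tiled(dim=[2, 2]): one iteration per (i, j) block.
for i in range(2):
    for j in range(2):
        # Same region that wp.tile_load reads with offset=(i*TILE_M, j*TILE_N).
        tile = x[i * TILE_M:(i + 1) * TILE_M, j * TILE_N:(j + 1) * TILE_N]
        rows = i * 2 * TILE_M + evens          # offset[0] (i*TWO_M) + indices
        cols = j * TILE_N + np.arange(TILE_N)  # offset[1] + tile column
        y[np.ix_(rows, cols)] = tile

# Rows of x appear on the even rows of y; the odd rows stay zero.
assert np.array_equal(y[0::2], x)
assert not y[1::2].any()
```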