warp.tile_load_indexed

warp.tile_load_indexed(
a: Array[Any],
indices: Tile[int32, tuple[int]],
shape: tuple[int, ...],
offset: tuple[int, ...],
axis: int32,
storage: str,
) -> Tile[Any, tuple[int, ...]]
  • Kernel

  • Differentiable

Loads a tile from an array in global memory, where the positions read along a specified axis are selected by a 1D tile of indices instead of being read contiguously.

param a:

The source array in global memory

param indices:

A 1D tile of integer indices that selects elements of a along axis.

param shape:

Shape of the tile to load. It must have the same number of dimensions as a, and its extent along axis must equal the number of elements in the indices tile.

param offset:

Offset in the source array to begin reading from (optional)

param axis:

Axis of a that indices refer to

param storage:

The storage location for the tile: "register" for registers (default) or "shared" for shared memory.

returns:

A tile with shape as specified and data type the same as the source array

This example shows how to select and store the even-indexed rows of a 2D array:

import numpy as np
import warp as wp

TILE_M = wp.constant(2)
TILE_N = wp.constant(2)
HALF_M = wp.constant(TILE_M // 2)
HALF_N = wp.constant(TILE_N // 2)

@wp.kernel
def compute(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
    i, j = wp.tid()

    evens = wp.tile_arange(HALF_M, dtype=int, storage="shared") * 2

    t0 = wp.tile_load_indexed(x, indices=evens, shape=(HALF_M, TILE_N), offset=(i*TILE_M, j*TILE_N), axis=0, storage="register")
    wp.tile_store(y, t0, offset=(i*HALF_M, j*TILE_N))

M = TILE_M * 2
N = TILE_N * 2

arr = np.arange(M * N).reshape(M, N)

x = wp.array(arr, dtype=float)
y = wp.zeros((M // 2, N), dtype=float)

device = wp.get_device()  # default device, where x and y were allocated

wp.launch_tiled(compute, dim=[2,2], inputs=[x], outputs=[y], block_dim=32, device=device)

print(x.numpy())
print(y.numpy())

Prints:

[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]
 [12. 13. 14. 15.]]

[[ 0.  1.  2.  3.]
 [ 8.  9. 10. 11.]]
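
The index mapping can be easier to follow in a host-side model. The sketch below is pure Python and is not part of Warp: load_indexed_2d is a hypothetical helper, limited to 2D nested lists. It assumes that each index is added to the offset along axis, which is consistent with the example above but is not stated explicitly in the parameter descriptions.

```python
def load_indexed_2d(a, indices, shape, offset, axis):
    """Hypothetical reference model of a 2D indexed tile load (not a Warp API).

    Assumption: along `axis`, the source coordinate is offset + indices[k];
    along the other axis, it is offset + k (a contiguous read).
    """
    rows, cols = shape
    # The tile extent along `axis` must match the number of indices.
    assert shape[axis] == len(indices)

    tile = []
    for r in range(rows):
        out_row = []
        for c in range(cols):
            src_r = offset[0] + (indices[r] if axis == 0 else r)
            src_c = offset[1] + (indices[c] if axis == 1 else c)
            out_row.append(a[src_r][src_c])
        tile.append(out_row)
    return tile


# Same 4x4 source values as the example above.
x = [[float(r * 4 + c) for c in range(4)] for r in range(4)]

# Even-indexed rows: indices apply to axis 0.
even_rows = load_indexed_2d(x, [0, 2], shape=(2, 4), offset=(0, 0), axis=0)
print(even_rows)  # [[0.0, 1.0, 2.0, 3.0], [8.0, 9.0, 10.0, 11.0]]

# Even-indexed columns: indices apply to axis 1.
even_cols = load_indexed_2d(x, [0, 2], shape=(4, 2), offset=(0, 0), axis=1)
print(even_cols)  # [[0.0, 2.0], [4.0, 6.0], [8.0, 10.0], [12.0, 14.0]]
```

In the kernel above, each of the four blocks performs the equivalent of load_indexed_2d(x, [0], (1, 2), (i * 2, j * 2), axis=0), so the blocks together copy rows 0 and 2 of x into y.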