warp.tile_load_indexed

warp.tile_load_indexed(
    a: Array[Any],
    indices: Tile[int32, tuple[int]],
    shape: tuple[int, ...],
    offset: tuple[int, ...],
    axis: int32,
    storage: str,
) -> Tile[Any, tuple[int, ...]]
  • Kernel

  • Differentiable

Load a tile from a global memory array, with loads along a specified axis mapped according to a 1D tile of indices.

Parameters:
  • a – The source array in global memory.

  • indices – A 1D tile of integer indices selecting which elements of a are read along axis, relative to offset.

  • shape – Shape of the tile to load. It must have the same number of dimensions as a, and its extent along axis must equal the number of elements in the indices tile.

  • offset – Offset in the source array to begin reading from (optional).

  • axis – Axis of a that indices refer to.

  • storage – The storage location for the tile: "register" for registers (default) or "shared" for shared memory.

Returns:

A tile with shape as specified and data type the same as the source array.
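The indexing rule can be checked on the host with a small model. The following pure-Python sketch is not Warp code, and the helper name tile_load_indexed_ref is hypothetical. It assumes, consistently with the example below, that each index is added to offset[axis] while the other axis is read contiguously from its own offset; it only handles 2D inputs and does not model out-of-bounds reads.

```python
def tile_load_indexed_ref(a, indices, shape, offset=(0, 0), axis=0):
    """Host-side model of an indexed 2D tile load (hypothetical helper, not Warp).

    a is a list of rows. Along `axis`, tile position k reads source position
    offset[axis] + indices[k]; along the other axis, position k reads
    offset[other] + k. Out-of-bounds accesses are not handled.
    """
    if len(shape) != 2 or len(offset) != 2:
        raise ValueError("this sketch only models 2D inputs")
    if shape[axis] != len(indices):
        raise ValueError("shape[axis] must equal the number of indices")
    tile = []
    for r in range(shape[0]):
        row = []
        for c in range(shape[1]):
            src = [offset[0] + r, offset[1] + c]
            k = (r, c)[axis]                       # tile coordinate along `axis`
            src[axis] = offset[axis] + indices[k]  # remap it through the indices
            row.append(a[src[0]][src[1]])
        tile.append(row)
    return tile


# Reproduce the example: a 4 x 4 input processed as a 2 x 2 grid of 2 x 2 windows,
# keeping local row 0 (indices = [0]) of each window.
x = [[4 * r + c for c in range(4)] for r in range(4)]
y = [[0] * 4 for _ in range(2)]
for i in range(2):
    for j in range(2):
        t0 = tile_load_indexed_ref(x, [0], shape=(1, 2), offset=(2 * i, 2 * j), axis=0)
        for r, row in enumerate(t0):
            for c, v in enumerate(row):
                y[i + r][2 * j + c] = v  # store at offset (i * HALF_M, j * TILE_N)
print(y)  # [[0, 1, 2, 3], [8, 9, 10, 11]]
```

Under these assumptions the indexed load behaves like a gather along one axis of the window that starts at offset, similar in spirit to numpy.take restricted to a sub-block.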

Example

This example shows how to select and store the even-indexed rows of a 2D array.

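import numpy as np
import warp as wp

TILE_M = wp.constant(2)
TILE_N = wp.constant(2)
HALF_M = wp.constant(TILE_M // 2)
HALF_N = wp.constant(TILE_N // 2)

@wp.kernel
def compute(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
    # (i, j) identifies the TILE_M x TILE_N window this block processes
    i, j = wp.tid()

    # Local row indices 0, 2, 4, ... (HALF_M of them) within the window
    evens = wp.tile_arange(HALF_M, dtype=int, storage="shared") * 2

    # Gather the selected rows of the window starting at (i*TILE_M, j*TILE_N);
    # the indices select rows along axis 0, relative to that offset
    t0 = wp.tile_load_indexed(x, indices=evens, shape=(HALF_M, TILE_N), offset=(i*TILE_M, j*TILE_N), axis=0, storage="register")

    # Each input window contributes HALF_M rows to the output
    wp.tile_store(y, t0, offset=(i*HALF_M, j*TILE_N))

M = TILE_M * 2
N = TILE_N * 2

arr = np.arange(M * N).reshape(M, N)

x = wp.array(arr, dtype=float)
y = wp.zeros((M // 2, N), dtype=float)

# Launch a 2 x 2 grid of tiles on the default device
wp.launch_tiled(compute, dim=[2, 2], inputs=[x], outputs=[y], block_dim=32)

print(x.numpy())
print(y.numpy())
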
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]
 [12. 13. 14. 15.]]

[[ 0.  1.  2.  3.]
 [ 8.  9. 10. 11.]]
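The output contains rows 0 and 2 because of the offset arithmetic in the kernel: window i starts reading at row i*TILE_M and adds the local indices 0, 2, ..., so the global source rows are i*TILE_M + 2k, which land in y at row i*HALF_M + k. The short sketch below enumerates that mapping in plain Python; the helper name selected_rows is hypothetical and only mirrors the example's index arithmetic, it does not call Warp.

```python
def selected_rows(tile_m, num_tiles_m):
    """Return (source row in x, destination row in y) pairs for the example.

    Mirrors the kernel: evens = tile_arange(HALF_M) * 2,
    source row = i * TILE_M + evens[k], destination row = i * HALF_M + k.
    """
    half_m = tile_m // 2
    mapping = []
    for i in range(num_tiles_m):
        evens = [2 * k for k in range(half_m)]
        for k, e in enumerate(evens):
            mapping.append((i * tile_m + e, i * half_m + k))
    return mapping


print(selected_rows(2, 2))  # [(0, 0), (2, 1)]
print(selected_rows(4, 2))  # [(0, 0), (2, 1), (4, 2), (6, 3)]
```

The selected global rows are all even only when TILE_M is even; with TILE_M = 3, for instance, the same arithmetic would pick rows 0 and 3.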