warp.tile_load_indexed
warp.tile_load_indexed(
    a: Array[Any],
    indices: Tile[int32, tuple[int]],
    shape: tuple[int, ...],
    offset: tuple[int, ...],
    axis: int32,
    storage: str,
)
Kernel
Differentiable
Loads a tile from a global memory array, with loads along a specified axis mapped according to a 1D tile of indices.
- param a:
The source array in global memory
- param indices:
A 1D tile of integer indices mapping to elements in a.
- param shape:
Shape of the tile to load. It must have the same number of dimensions as a, and along axis it must have the same number of elements as the indices tile.
- param offset:
Offset in the source array to begin reading from (optional)
- param axis:
Axis of a that indices refer to
- param storage:
The storage location for the tile: "register" for registers (default) or "shared" for shared memory.
- returns:
A tile with shape as specified and data type the same as the source array
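The indexing rule implied by these parameters can be modeled in plain Python. The sketch below is only an illustration, not Warp's implementation: `load_indexed_reference` is a hypothetical helper, it handles 2D nested lists only, and it assumes that `offset` is added to the gathered index along `axis`, which is what the example output on this page suggests. It does not model out-of-bounds handling or storage locations.

```python
def load_indexed_reference(a, indices, shape, offset=(0, 0), axis=0):
    """Pure-Python model of an indexed tile load on a 2D nested list.

    Hypothetical helper for illustration only; not part of Warp.
    """
    if len(shape) != 2 or len(offset) != 2:
        raise ValueError("this sketch only handles 2D arrays")
    if shape[axis] != len(indices):
        raise ValueError("shape[axis] must equal the number of indices")

    rows, cols = shape
    out = []
    for r in range(rows):
        out_row = []
        for c in range(cols):
            # Along `axis` the source position comes from `indices`;
            # along the other axis the load is contiguous.
            src_r = offset[0] + (indices[r] if axis == 0 else r)
            src_c = offset[1] + (indices[c] if axis == 1 else c)
            out_row.append(a[src_r][src_c])
        out.append(out_row)
    return out


# 4x4 source holding the values 0..15 in row-major order
x = [[float(4 * r + c) for c in range(4)] for r in range(4)]

# Gather rows 0 and 2 (axis=0), then columns 1 and 3 (axis=1)
even_rows = load_indexed_reference(x, indices=[0, 2], shape=(2, 4), axis=0)
odd_cols = load_indexed_reference(x, indices=[1, 3], shape=(4, 2), axis=1)
```

Here `even_rows` holds rows 0 and 2 of `x`, and `odd_cols` holds columns 1 and 3. Along the non-indexed axis the load behaves like a regular contiguous tile load starting at `offset`.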
This example shows how to select and store the even-indexed rows of a 2D array:
import numpy as np
import warp as wp

TILE_M = wp.constant(2)
TILE_N = wp.constant(2)
HALF_M = wp.constant(TILE_M // 2)
HALF_N = wp.constant(TILE_N // 2)

@wp.kernel
def compute(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
    i, j = wp.tid()

    # indices of the even rows within each block of TILE_M rows
    evens = wp.tile_arange(HALF_M, dtype=int, storage="shared") * 2

    # gather those rows (axis=0) from the block that starts at offset
    t0 = wp.tile_load_indexed(x, indices=evens, shape=(HALF_M, TILE_N), offset=(i*TILE_M, j*TILE_N), axis=0, storage="register")
    wp.tile_store(y, t0, offset=(i*HALF_M, j*TILE_N))

M = TILE_M * 2
N = TILE_N * 2

arr = np.arange(M * N).reshape(M, N)

# run on the default device
device = wp.get_device()

x = wp.array(arr, dtype=float)
y = wp.zeros((M // 2, N), dtype=float)

wp.launch_tiled(compute, dim=[2,2], inputs=[x], outputs=[y], block_dim=32, device=device)

print(x.numpy())
print(y.numpy())
Prints:
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]
 [12. 13. 14. 15.]]
[[ 0.  1.  2.  3.]
 [ 8.  9. 10. 11.]]