warp.tile_store_indexed
warp.tile_store_indexed(a: Array[Any], indices: Tile[int32, tuple[int]], t: Tile[Any, tuple[int, ...]], offset: tuple[int, ...], axis: int32)
Kernel
Differentiable
Store a tile to an array in global memory, mapping positions along a specified axis through a 1D tile of indices.
- param a:
The destination array in global memory.
- param indices:
A 1D tile of integer indices mapping to elements in `a`.
- param t:
The source tile to store data from. It must have the same data type and number of dimensions as the destination array, and along `axis` it must have the same number of elements as the `indices` tile.
- param offset:
Offset in the destination array (optional).
- param axis:
Axis of `a` that `indices` refer to.
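To make the index mapping concrete, here is a minimal pure-Python model of the store for 2D arrays. It is not Warp's implementation: the name `store_indexed_reference` and the helper `simulate_example` are hypothetical, and the rule that `indices` are applied relative to `offset` along `axis` is an assumption inferred from the example output, not taken from the Warp source. Out-of-bounds handling is not modeled.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def store_indexed_reference(
    a: Matrix,
    indices: Sequence[int],
    t: Matrix,
    offset: Tuple[int, int],
    axis: int,
) -> Matrix:
    """Hypothetical model of an indexed tile store into a 2D nested list.

    Assumption: along `axis`, tile position k goes to offset[axis] + indices[k];
    along the other axis, tile position p goes to offset[other] + p.
    """
    rows, cols = len(t), len(t[0])
    if axis not in (0, 1):
        raise ValueError("this sketch only models 2D arrays (axis 0 or 1)")
    if len(indices) != (rows if axis == 0 else cols):
        raise ValueError("indices must match the tile's extent along axis")
    for r in range(rows):
        for c in range(cols):
            # Only the coordinate along `axis` is remapped through `indices`.
            dst_r = offset[0] + (indices[r] if axis == 0 else r)
            dst_c = offset[1] + (indices[c] if axis == 1 else c)
            a[dst_r][dst_c] = t[r][c]
    return a


def simulate_example() -> Matrix:
    """Serially replay the 2x2 launch grid of the Warp example below."""
    TILE_M, TILE_N = 2, 2
    M, N = TILE_M * 2, TILE_N * 2
    x = [[float(r * N + c) for c in range(N)] for r in range(M)]
    y = [[0.0] * N for _ in range(M * 2)]
    evens = [k * 2 for k in range(TILE_M)]  # same values as evens_M
    for i in range(2):
        for j in range(2):
            # Equivalent of tile_load at offset (i*TILE_M, j*TILE_N).
            tile = [row[j * TILE_N:(j + 1) * TILE_N]
                    for row in x[i * TILE_M:(i + 1) * TILE_M]]
            store_indexed_reference(y, evens, tile,
                                    (i * TILE_M * 2, j * TILE_N), axis=0)
    return y
```

Running `simulate_example()` serially reproduces the second array printed by the Warp example, which is a quick way to check the assumed offset-plus-index rule.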
This example shows how to map tile rows to the even rows of a 2D array:
```python
import numpy as np
import warp as wp

device = wp.get_device()  # the current default Warp device

TILE_M = wp.constant(2)
TILE_N = wp.constant(2)
TWO_M = wp.constant(TILE_M * 2)
TWO_N = wp.constant(TILE_N * 2)


@wp.kernel
def compute(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
    i, j = wp.tid()

    # Load a TILE_M x TILE_N tile from x
    t = wp.tile_load(x, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N), storage="register")

    # Row indices 0, 2, ...: every other row of the destination block
    evens_M = wp.tile_arange(TILE_M, dtype=int, storage="shared") * 2

    # Store the tile's rows to the even rows of y, starting at (i*TWO_M, j*TILE_N)
    wp.tile_store_indexed(y, indices=evens_M, t=t, offset=(i * TWO_M, j * TILE_N), axis=0)


M = TILE_M * 2
N = TILE_N * 2

arr = np.arange(M * N, dtype=float).reshape(M, N)

x = wp.array(arr, dtype=float, requires_grad=True, device=device)
y = wp.zeros((M * 2, N), dtype=float, requires_grad=True, device=device)

wp.launch_tiled(compute, dim=[2, 2], inputs=[x], outputs=[y], block_dim=32, device=device)

print(x.numpy())
print(y.numpy())
```
Prints:
```
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]
 [12. 13. 14. 15.]]
[[ 0.  1.  2.  3.]
 [ 0.  0.  0.  0.]
 [ 4.  5.  6.  7.]
 [ 0.  0.  0.  0.]
 [ 8.  9. 10. 11.]
 [ 0.  0.  0.  0.]
 [12. 13. 14. 15.]
 [ 0.  0.  0.  0.]]
```
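The placement in the printed output follows from the index arithmetic in the kernel. Writing $T_M$ for `TILE_M`, the tile launched at `(i, j)` loads source row $r = i\,T_M + k$ for $k = 0, \dots, T_M - 1$. Assuming, as the output indicates, that the store sends tile row $k$ to destination row `offset[0] + evens_M[k]`:

$$
\underbrace{i \cdot 2T_M}_{\texttt{offset[0]}} + \underbrace{2k}_{\texttt{evens\_M}[k]} = 2\,(i\,T_M + k) = 2r
$$

So every source row $r$ lands in row $2r$ of `y`. Columns are copied unchanged because the store uses the same column offset `j*TILE_N` as the load, and the odd rows keep the zeros that `wp.zeros` wrote.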