CUDA Parallel
Warning
Python exposure of parallel algorithms is in public beta. The API is subject to change without notice.
Algorithms
- cuda.parallel.experimental.algorithms.merge_sort(d_in_keys: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, d_in_items: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase | None, d_out_keys: cuda.parallel.experimental.typing.DeviceArrayLike, d_out_items: cuda.parallel.experimental.typing.DeviceArrayLike | None, op: Callable)
Implements a device-wide merge sort of d_in_keys using the comparison operator op.

Example

Below, merge_sort is used to sort a sequence of keys in place. It also rearranges the items according to the keys' order.

```python
import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms


def compare_op(lhs, rhs):
    return np.uint8(lhs < rhs)


h_in_keys = np.array([-5, 0, 2, -3, 2, 4, 0, -1, 2, 8], dtype="int32")
h_in_items = np.array(
    [-3.2, 2.2, 1.9, 4.0, -3.9, 2.7, 0, 8.3 - 1, 2.9, 5.4], dtype="float32"
)

d_in_keys = cp.asarray(h_in_keys)
d_in_items = cp.asarray(h_in_items)

# Instantiate merge_sort for the given keys, items, and operator
merge_sort = algorithms.merge_sort(
    d_in_keys, d_in_items, d_in_keys, d_in_items, compare_op
)

# Determine temporary device storage requirements
temp_storage_size = merge_sort(
    None, d_in_keys, d_in_items, d_in_keys, d_in_items, d_in_keys.size
)

# Allocate temporary storage
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

# Run merge_sort
merge_sort(
    d_temp_storage, d_in_keys, d_in_items, d_in_keys, d_in_items, d_in_keys.size
)

# Check the result is correct
h_out_keys = cp.asnumpy(d_in_keys)
h_out_items = cp.asnumpy(d_in_items)

argsort = np.argsort(h_in_keys, stable=True)
h_in_keys = np.array(h_in_keys)[argsort]
h_in_items = np.array(h_in_items)[argsort]

np.testing.assert_array_equal(h_out_keys, h_in_keys)
np.testing.assert_array_equal(h_out_items, h_in_items)
```
- Parameters
d_in_keys – Device array or iterator containing the input keys to be sorted
d_in_items – Optional device array or iterator containing each key's corresponding item; may be None for a keys-only sort (see the sketch after this entry)
d_out_keys – Device array to store the sorted keys
d_out_items – Optional device array to store the sorted items
op – Callable representing the comparison operator
- Returns
A callable object that can be used to perform the merge sort
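Since both items arguments are optional, a keys-only sort should be possible by passing None in their place. A minimal sketch, assuming None is accepted exactly where the signature permits it and that the same two-phase calling convention applies:

```python
import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms


def compare_op(lhs, rhs):
    return np.uint8(lhs < rhs)


d_keys = cp.array([4, 2, 9, 1, 7], dtype="int32")

# Keys-only sort: None is passed for both items arrays (an assumption
# based on the Optional types in the signature, not a documented example)
sorter = algorithms.merge_sort(d_keys, None, d_keys, None, compare_op)

# Query temporary storage size, allocate it, then run the sort
temp_storage_size = sorter(None, d_keys, None, d_keys, None, d_keys.size)
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)
sorter(d_temp_storage, d_keys, None, d_keys, None, d_keys.size)

np.testing.assert_array_equal(
    cp.asnumpy(d_keys), np.array([1, 2, 4, 7, 9], dtype="int32")
)
```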
- cuda.parallel.experimental.algorithms.reduce_into(d_in: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, d_out: cuda.parallel.experimental.typing.DeviceArrayLike, op: Callable, h_init: numpy.ndarray)
Computes a device-wide reduction using the specified binary op and initial value h_init.

Example

Below, reduce_into is used to compute the minimum value of a sequence of integers.

```python
import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms


def min_op(a, b):
    return a if a < b else b


dtype = np.int32
h_init = np.array([42], dtype=dtype)
d_input = cp.array([8, 6, 7, 5, 3, 0, 9], dtype=dtype)
d_output = cp.empty(1, dtype=dtype)

# Instantiate reduction for the given operator and initial value
reduce_into = algorithms.reduce_into(d_input, d_output, min_op, h_init)

# Determine temporary device storage requirements
temp_storage_size = reduce_into(None, d_input, d_output, len(d_input), h_init)

# Allocate temporary storage
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

# Run reduction
reduce_into(d_temp_storage, d_input, d_output, len(d_input), h_init)

# Check the result is correct
expected_output = 0
assert (d_output == expected_output).all()
```
- Parameters
d_in – Device array or iterator containing the input sequence of data items
d_out – Device array (of size 1) that will store the result of the reduction
op – Callable representing the binary operator to apply
h_init – NumPy array storing the initial value of the reduction
- Returns
A callable object that can be used to perform the reduction
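The returned callable is specialized for the dtypes and operator supplied at construction, so it can plausibly be reused across invocations. A minimal sketch of that pattern (reuse across inputs of different sizes is an assumption based on the build/call split shown above, not a documented guarantee):

```python
import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms


def add_op(a, b):
    return a + b


h_init = np.array([0], dtype=np.int32)
d_output = cp.empty(1, dtype=np.int32)
d_small = cp.array([1, 2, 3], dtype=np.int32)
d_large = cp.arange(1000, dtype=np.int32)

# Build once for int32 inputs and the add operator
reduce_into = algorithms.reduce_into(d_small, d_output, add_op, h_init)

for d_in in (d_small, d_large):
    # Temporary storage requirements depend on the input size,
    # so they are re-queried before each run
    temp_storage_size = reduce_into(None, d_in, d_output, len(d_in), h_init)
    d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)
    reduce_into(d_temp_storage, d_in, d_output, len(d_in), h_init)

# d_output now holds the sum of d_large: 0 + 1 + ... + 999
assert (d_output == 499500).all()
```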
- cuda.parallel.experimental.algorithms.scan(d_in: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, d_out: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, op: Callable, h_init: numpy.ndarray)
Computes a device-wide scan using the specified binary op and initial value h_init.

Example

Below, scan is used to compute an exclusive scan of a sequence of integers.

```python
import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms


def max_op(a, b):
    return max(a, b)


h_init = np.array([1], dtype="int32")
d_input = cp.array([-5, 0, 2, -3, 2, 4, 0, -1, 2, 8], dtype="int32")
d_output = cp.empty_like(d_input, dtype="int32")

# Instantiate scan for the given operator and initial value
scanner = algorithms.scan(d_input, d_output, max_op, h_init)

# Determine temporary device storage requirements
temp_storage_size = scanner(None, d_input, d_output, d_input.size, h_init)

# Allocate temporary storage
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

# Run scan
scanner(d_temp_storage, d_input, d_output, d_input.size, h_init)

# Check the result is correct
expected = np.asarray([1, 1, 1, 2, 2, 2, 4, 4, 4, 4])
np.testing.assert_equal(d_output.get(), expected)
```
- Parameters
d_in – Device array or iterator containing the input sequence of data items
d_out – Device array that will store the result of the scan
op – Callable representing the binary operator to apply
h_init – NumPy array storing the initial value of the scan
- Returns
A callable object that can be used to perform the scan
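Substituting addition for the operator gives the familiar exclusive prefix sum. A minimal sketch following the calling convention above (the operator and values here are illustrative, not from the original documentation):

```python
import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms


def add_op(a, b):
    return a + b


h_init = np.array([0], dtype="int32")
d_input = cp.array([1, 2, 3, 4], dtype="int32")
d_output = cp.empty_like(d_input)

scanner = algorithms.scan(d_input, d_output, add_op, h_init)

temp_storage_size = scanner(None, d_input, d_output, d_input.size, h_init)
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)
scanner(d_temp_storage, d_input, d_output, d_input.size, h_init)

# Exclusive scan: each output element is the sum of all preceding inputs
np.testing.assert_equal(d_output.get(), np.asarray([0, 1, 3, 6]))
```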
- cuda.parallel.experimental.algorithms.segmented_reduce(d_in: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, d_out: cuda.parallel.experimental.typing.DeviceArrayLike, start_offsets_in: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, end_offsets_in: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, op: Callable, h_init: numpy.ndarray)
Computes a device-wide segmented reduction using the specified binary op and initial value h_init.

Example

Below, segmented_reduce is used to compute the minimum value within each segment of a sequence of integers.

```python
import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms


def min_op(a, b):
    return a if a < b else b


dtype = np.dtype(np.int32)
max_val = np.iinfo(dtype).max
h_init = np.asarray(max_val, dtype=dtype)

offsets = cp.array([0, 7, 11, 16], dtype=np.int64)
first_segment = (8, 6, 7, 5, 3, 0, 9)
second_segment = (-4, 3, 0, 1)
third_segment = (3, 1, 11, 25, 8)
d_input = cp.array(
    [*first_segment, *second_segment, *third_segment],
    dtype=dtype,
)

start_o = offsets[:-1]
end_o = offsets[1:]

n_segments = start_o.size
d_output = cp.empty(n_segments, dtype=dtype)

# Instantiate reduction for the given operator and initial value
segmented_reduce = algorithms.segmented_reduce(
    d_input, d_output, start_o, end_o, min_op, h_init
)

# Determine temporary device storage requirements
temp_storage_size = segmented_reduce(
    None, d_input, d_output, n_segments, start_o, end_o, h_init
)

# Allocate temporary storage
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

# Run reduction
segmented_reduce(
    d_temp_storage, d_input, d_output, n_segments, start_o, end_o, h_init
)

# Check the result is correct
expected_output = cp.asarray([0, -4, 1], dtype=d_output.dtype)
assert (d_output == expected_output).all()
```
- Parameters
d_in – Device array or iterator containing the input sequence of data items
d_out – Device array that will store the result of the reduction
start_offsets_in – Device array or iterator containing offsets to start of segments
end_offsets_in – Device array or iterator containing offsets to end of segments
op – Callable representing the binary operator to apply
h_init – NumPy array storing the initial value of the reduction
- Returns
A callable object that can be used to perform the reduction
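Since the offset arguments accept iterators, uniform-size segments can be described without materializing an offsets array. A sketch assuming the CountingIterator and TransformIterator documented below are accepted as offset iterators, as the signature suggests:

```python
import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms
import cuda.parallel.experimental.iterators as iterators


def add_op(a, b):
    return a + b


def scale_op(i):
    return 4 * i  # segment size hard-coded to 4 for simplicity


dtype = np.int32
h_init = np.asarray(0, dtype=dtype)

n_segments = 3
d_input = cp.arange(4 * n_segments, dtype=dtype)  # [0, 1, ..., 11]

# Segment i spans [4 * i, 4 * (i + 1)); the end offsets are the
# start offsets shifted by one segment
start_o = iterators.TransformIterator(
    iterators.CountingIterator(np.int64(0)), scale_op
)
end_o = iterators.TransformIterator(
    iterators.CountingIterator(np.int64(1)), scale_op
)

d_output = cp.empty(n_segments, dtype=dtype)

segmented_reduce = algorithms.segmented_reduce(
    d_input, d_output, start_o, end_o, add_op, h_init
)
temp_storage_size = segmented_reduce(
    None, d_input, d_output, n_segments, start_o, end_o, h_init
)
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)
segmented_reduce(
    d_temp_storage, d_input, d_output, n_segments, start_o, end_o, h_init
)

# Sums of [0..3], [4..7], [8..11]
assert (d_output == cp.asarray([6, 22, 38], dtype=dtype)).all()
```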
Iterators
- cuda.parallel.experimental.iterators.CacheModifiedInputIterator(device_array, modifier, prefix='')
A random-access, cache-modified iterator that wraps a native device pointer.
Similar to https://nvidia.github.io/cccl/cub/api/classcub_1_1CacheModifiedInputIterator.html
Currently the only supported modifier is “stream” (LOAD_CS).
Example
The code snippet below demonstrates the usage of a CacheModifiedInputIterator:

```python
import functools

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms
import cuda.parallel.experimental.iterators as iterators


def add_op(a, b):
    return a + b


values = [8, 6, 7, 5, 3, 0, 9]
d_input = cp.array(values, dtype=np.int32)

iterator = iterators.CacheModifiedInputIterator(
    d_input, modifier="stream"
)  # Input sequence
h_init = np.array([0], dtype=np.int32)  # Initial value for the reduction
d_output = cp.empty(1, dtype=np.int32)  # Storage for output

# Instantiate reduction, determine storage requirements, and allocate storage
reduce_into = algorithms.reduce_into(iterator, d_output, add_op, h_init)
temp_storage_size = reduce_into(None, iterator, d_output, len(values), h_init)
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

# Run reduction
reduce_into(d_temp_storage, iterator, d_output, len(values), h_init)

expected_output = functools.reduce(lambda a, b: a + b, values)
assert (d_output == expected_output).all()
```
- Parameters
device_array – CUDA device array storing the input sequence of data items
modifier – The PTX cache load modifier
prefix – An optional prefix added to the iterator’s methods to prevent name collisions.
- Returns
A CacheModifiedInputIterator object initialized with device_array
- cuda.parallel.experimental.iterators.ConstantIterator(value)
Returns an Iterator representing a sequence of constant values.
Similar to https://nvidia.github.io/cccl/thrust/api/classthrust_1_1constant__iterator.html
Example
The code snippet below demonstrates the usage of a ConstantIterator representing the sequence [10, 10, 10]:

```python
import functools

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms
import cuda.parallel.experimental.iterators as iterators


def add_op(a, b):
    return a + b


value = 10
num_items = 3

constant_it = iterators.ConstantIterator(np.int32(value))  # Input sequence
h_init = np.array([0], dtype=np.int32)  # Initial value for the reduction
d_output = cp.empty(1, dtype=np.int32)  # Storage for output

# Instantiate reduction, determine storage requirements, and allocate storage
reduce_into = algorithms.reduce_into(constant_it, d_output, add_op, h_init)
temp_storage_size = reduce_into(None, constant_it, d_output, num_items, h_init)
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

# Run reduction
reduce_into(d_temp_storage, constant_it, d_output, num_items, h_init)

expected_output = functools.reduce(lambda a, b: a + b, [value] * num_items)
assert (d_output == expected_output).all()
```
- Parameters
value – The value of every item in the sequence
- Returns
A ConstantIterator object initialized to value
- cuda.parallel.experimental.iterators.CountingIterator(offset)
Returns an Iterator representing a sequence of incrementing values.
Similar to https://nvidia.github.io/cccl/thrust/api/classthrust_1_1counting__iterator.html
Example
The code snippet below demonstrates the usage of a CountingIterator representing the sequence [10, 11, 12]:

```python
import functools

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms
import cuda.parallel.experimental.iterators as iterators


def add_op(a, b):
    return a + b


first_item = 10
num_items = 3

first_it = iterators.CountingIterator(np.int32(first_item))  # Input sequence
h_init = np.array([0], dtype=np.int32)  # Initial value for the reduction
d_output = cp.empty(1, dtype=np.int32)  # Storage for output

# Instantiate reduction, determine storage requirements, and allocate storage
reduce_into = algorithms.reduce_into(first_it, d_output, add_op, h_init)
temp_storage_size = reduce_into(None, first_it, d_output, num_items, h_init)
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

# Run reduction
reduce_into(d_temp_storage, first_it, d_output, num_items, h_init)

expected_output = functools.reduce(
    lambda a, b: a + b, range(first_item, first_item + num_items)
)
assert (d_output == expected_output).all()
```
- Parameters
offset – The initial value of the sequence
- Returns
A CountingIterator object initialized to offset
- cuda.parallel.experimental.iterators.TransformIterator(it, op)
Returns an Iterator representing a transformed sequence of values.
Similar to https://nvidia.github.io/cccl/thrust/api/classthrust_1_1transform__iterator.html
Example
The code snippet below demonstrates the usage of a TransformIterator composed with a CountingIterator, transforming the sequence [10, 11, 12] by squaring each item before reducing the output:

```python
import functools

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms
import cuda.parallel.experimental.iterators as iterators


def add_op(a, b):
    return a + b


def square_op(a):
    return a**2


first_item = 10
num_items = 3

transform_it = iterators.TransformIterator(
    iterators.CountingIterator(np.int32(first_item)), square_op
)  # Input sequence
h_init = np.array([0], dtype=np.int32)  # Initial value for the reduction
d_output = cp.empty(1, dtype=np.int32)  # Storage for output

# Instantiate reduction, determine storage requirements, and allocate storage
reduce_into = algorithms.reduce_into(transform_it, d_output, add_op, h_init)
temp_storage_size = reduce_into(None, transform_it, d_output, num_items, h_init)
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

# Run reduction
reduce_into(d_temp_storage, transform_it, d_output, num_items, h_init)

expected_output = functools.reduce(
    lambda a, b: a + b, [a**2 for a in range(first_item, first_item + num_items)]
)
assert (d_output == expected_output).all()
```
- Parameters
it – The iterator object to be transformed
op – The transform operation
- Returns
A TransformIterator object that transforms the items in it using op
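A TransformIterator can also wrap a device array directly, as the gpu_struct example under Utilities does. A minimal sketch computing a sum of absolute values that way (the operator and values here are illustrative):

```python
import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms
import cuda.parallel.experimental.iterators as iterators


def add_op(a, b):
    return a + b


def abs_op(a):
    return a if a >= 0 else -a


values = [-8, 6, -7, 5, -3, 0, 9]
d_input = cp.array(values, dtype=np.int32)

# Wrap the device array itself rather than a CountingIterator
transform_it = iterators.TransformIterator(d_input, abs_op)

h_init = np.array([0], dtype=np.int32)
d_output = cp.empty(1, dtype=np.int32)

reduce_into = algorithms.reduce_into(transform_it, d_output, add_op, h_init)
temp_storage_size = reduce_into(None, transform_it, d_output, len(values), h_init)
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)
reduce_into(d_temp_storage, transform_it, d_output, len(values), h_init)

expected_output = sum(abs(v) for v in values)
assert (d_output == expected_output).all()
```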
Utilities
- cuda.parallel.experimental.struct.gpu_struct(this: type) → Type[Any]
Defines the given class as being a GpuStruct.
A GpuStruct represents a value composed of one or more other values, and is defined as a class with annotated fields (similar to a dataclass). The type of each field must be a subclass of np.number, like np.int32 or np.float64.
Arrays of GpuStruct objects can be used as inputs to cuda.parallel algorithms.
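A minimal sketch of just the definition step: the decorated class carries a NumPy structured dtype (the same .dtype attribute the MinMax example below relies on). The Pair type here is hypothetical, introduced only for illustration:

```python
import cupy as cp
import numpy as np

from cuda.parallel.experimental.struct import gpu_struct


@gpu_struct
class Pair:
    first: np.int32
    second: np.float64


# The decorator attaches a NumPy structured dtype describing the layout,
# which can be used to allocate device arrays of the struct type
d_pairs = cp.empty(16, dtype=Pair.dtype)
print(d_pairs.dtype)  # expected: a structured dtype with 'first' and 'second' fields
```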
Example
The code snippet below shows how to use gpu_struct to define a MinMax type (composed of min_val and max_val fields), and perform a reduction on an input array of floating-point values to compute its smallest and largest absolute values:
```python
import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms
import cuda.parallel.experimental.iterators as iterators
from cuda.parallel.experimental.struct import gpu_struct


@gpu_struct
class MinMax:
    min_val: np.float64
    max_val: np.float64


def minmax_op(v1: MinMax, v2: MinMax):
    c_min = min(v1.min_val, v2.min_val)
    c_max = max(v1.max_val, v2.max_val)
    return MinMax(c_min, c_max)


def transform_op(v):
    av = abs(v)
    return MinMax(av, av)


nelems = 4096
d_in = cp.random.randn(nelems)

# Input values are transformed to MinMax structs on the fly, mapping the
# computation onto a data-parallel reduction whose commutative binary
# operator takes both operands of the same type
tr_it = iterators.TransformIterator(d_in, transform_op)

d_out = cp.empty(tuple(), dtype=MinMax.dtype)

# Initial value built from the identity elements of the
# minimum and maximum operators
h_init = MinMax(np.inf, -np.inf)

# Get the algorithm object
minmax_reduce = algorithms.reduce_into(tr_it, d_out, minmax_op, h_init)

# Allocate the needed temporary storage
tmp_sz = minmax_reduce(None, tr_it, d_out, nelems, h_init)
tmp_storage = cp.empty(tmp_sz, dtype=cp.uint8)

# Invoke the reduction algorithm
minmax_reduce(tmp_storage, tr_it, d_out, nelems, h_init)

# Check the values computed on the device against the host
actual = d_out.get()
h = np.abs(d_in.get())
expected = np.asarray([(h.min(), h.max())], dtype=MinMax.dtype)
assert actual == expected
```