CUDA Parallel

Warning

Python exposure of parallel algorithms is in public beta. The API is subject to change without notice.

Algorithms

cuda.parallel.experimental.algorithms.merge_sort(d_in_keys: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, d_in_items: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase | None, d_out_keys: cuda.parallel.experimental.typing.DeviceArrayLike, d_out_items: cuda.parallel.experimental.typing.DeviceArrayLike | None, op: Callable)

Implements a device-wide merge sort of d_in_keys using the comparison operator op; if provided, d_in_items are rearranged correspondingly.

Example

Below, merge_sort is used to sort a sequence of keys in place. It also rearranges the items according to the keys’ order.

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms

def compare_op(lhs, rhs):
    return np.uint8(lhs < rhs)

h_in_keys = np.array([-5, 0, 2, -3, 2, 4, 0, -1, 2, 8], dtype="int32")
h_in_items = np.array(
    [-3.2, 2.2, 1.9, 4.0, -3.9, 2.7, 0, 8.3 - 1, 2.9, 5.4], dtype="float32"
)

d_in_keys = cp.asarray(h_in_keys)
d_in_items = cp.asarray(h_in_items)

# Instantiate merge_sort for the given keys, items, and operator
merge_sort = algorithms.merge_sort(
    d_in_keys, d_in_items, d_in_keys, d_in_items, compare_op
)

# Determine temporary device storage requirements
temp_storage_size = merge_sort(
    None, d_in_keys, d_in_items, d_in_keys, d_in_items, d_in_keys.size
)

# Allocate temporary storage
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

# Run merge_sort
merge_sort(
    d_temp_storage, d_in_keys, d_in_items, d_in_keys, d_in_items, d_in_keys.size
)

# Check the result is correct
h_out_keys = cp.asnumpy(d_in_keys)
h_out_items = cp.asnumpy(d_in_items)

argsort = np.argsort(h_in_keys, stable=True)
h_in_keys = np.array(h_in_keys)[argsort]
h_in_items = np.array(h_in_items)[argsort]

np.testing.assert_array_equal(h_out_keys, h_in_keys)
np.testing.assert_array_equal(h_out_items, h_in_items)
Parameters
  • d_in_keys – Device array or iterator containing the input keys to be sorted

  • d_in_items – Optional device array or iterator that contains each key’s corresponding item

  • d_out_keys – Device array to store the sorted keys

  • d_out_items – Optional device array to store the sorted items

  • op – Callable representing the comparison operator

Returns

A callable object that can be used to perform the merge sort
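
Since d_in_items and d_out_items may be None per the signature above, keys can also be sorted without items. A minimal keys-only sketch following the same two-phase pattern as the example above (array names are illustrative):

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms

def compare_op(lhs, rhs):
    return np.uint8(lhs < rhs)

d_keys = cp.array([3, 1, 2], dtype="int32")

# Instantiate a keys-only sort by passing None for both items arrays
sorter = algorithms.merge_sort(d_keys, None, d_keys, None, compare_op)

# Query temporary storage size, allocate it, then run the sort
temp_storage_size = sorter(None, d_keys, None, d_keys, None, d_keys.size)
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)
sorter(d_temp_storage, d_keys, None, d_keys, None, d_keys.size)

np.testing.assert_array_equal(d_keys.get(), np.array([1, 2, 3], dtype="int32"))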

cuda.parallel.experimental.algorithms.reduce_into(d_in: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, d_out: cuda.parallel.experimental.typing.DeviceArrayLike, op: Callable, h_init: numpy.ndarray)

Computes a device-wide reduction using the specified binary op and initial value h_init.

Example

Below, reduce_into is used to compute the minimum value of a sequence of integers.

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms

def min_op(a, b):
    return a if a < b else b

dtype = np.int32
h_init = np.array([42], dtype=dtype)
d_input = cp.array([8, 6, 7, 5, 3, 0, 9], dtype=dtype)
d_output = cp.empty(1, dtype=dtype)

# Instantiate reduction for the given operator and initial value
reduce_into = algorithms.reduce_into(d_input, d_output, min_op, h_init)

# Determine temporary device storage requirements
temp_storage_size = reduce_into(None, d_input, d_output, len(d_input), h_init)

# Allocate temporary storage
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

# Run reduction
reduce_into(d_temp_storage, d_input, d_output, len(d_input), h_init)

# Check the result is correct
expected_output = 0
assert (d_output == expected_output).all()
Parameters
  • d_in – Device array or iterator containing the input sequence of data items

  • d_out – Device array (of size 1) that will store the result of the reduction

  • op – Callable representing the binary operator to apply

  • h_init – NumPy array storing the initial value of the reduction

Returns

A callable object that can be used to perform the reduction
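
The same two-phase pattern applies to any binary operator. For contrast with the minimum example above, a short sum-reduction sketch (names illustrative):

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms

def add_op(a, b):
    return a + b

h_init = np.array([0], dtype=np.int32)
d_input = cp.arange(10, dtype=np.int32)
d_output = cp.empty(1, dtype=np.int32)

# Build once, size the temporary storage, then run
summer = algorithms.reduce_into(d_input, d_output, add_op, h_init)
temp_storage_size = summer(None, d_input, d_output, d_input.size, h_init)
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)
summer(d_temp_storage, d_input, d_output, d_input.size, h_init)

assert d_output.get()[0] == 45  # 0 + 1 + ... + 9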

cuda.parallel.experimental.algorithms.scan(d_in: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, d_out: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, op: Callable, h_init: numpy.ndarray)

Computes a device-wide scan using the specified binary op and initial value h_init.

Example

Below, scan is used to compute an exclusive scan (with the maximum operator) of a sequence of integers.

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms

def max_op(a, b):
    return max(a, b)

h_init = np.array([1], dtype="int32")
d_input = cp.array([-5, 0, 2, -3, 2, 4, 0, -1, 2, 8], dtype="int32")
d_output = cp.empty_like(d_input, dtype="int32")

# Instantiate scan for the given operator and initial value
scanner = algorithms.scan(d_input, d_output, max_op, h_init)

# Determine temporary device storage requirements
temp_storage_size = scanner(None, d_input, d_output, d_input.size, h_init)

# Allocate temporary storage
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

# Run scan
scanner(d_temp_storage, d_input, d_output, d_input.size, h_init)

# Check the result is correct
expected = np.asarray([1, 1, 1, 2, 2, 2, 4, 4, 4, 4])
np.testing.assert_equal(d_output.get(), expected)
Parameters
  • d_in – Device array or iterator containing the input sequence of data items

  • d_out – Device array that will store the result of the scan

  • op – Callable representing the binary operator to apply

  • h_init – NumPy array storing the initial value of the scan

Returns

A callable object that can be used to perform the scan
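
A common special case is the prefix sum. The sketch below assumes the exclusive semantics demonstrated in the example above, so each output element excludes its own input:

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms

def add_op(a, b):
    return a + b

h_init = np.array([0], dtype="int32")
d_input = cp.array([1, 2, 3, 4], dtype="int32")
d_output = cp.empty_like(d_input)

prefix_sum = algorithms.scan(d_input, d_output, add_op, h_init)
temp_storage_size = prefix_sum(None, d_input, d_output, d_input.size, h_init)
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)
prefix_sum(d_temp_storage, d_input, d_output, d_input.size, h_init)

# exclusive sums of [1, 2, 3, 4] with initial value 0
np.testing.assert_equal(d_output.get(), np.asarray([0, 1, 3, 6]))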

cuda.parallel.experimental.algorithms.segmented_reduce(d_in: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, d_out: cuda.parallel.experimental.typing.DeviceArrayLike, start_offsets_in: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, end_offsets_in: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, op: Callable, h_init: numpy.ndarray)

Computes a device-wide segmented reduction using the specified binary op and initial value h_init.

Example

Below, segmented_reduce is used to compute the minimum value within each segment of a sequence of integers.

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms

def min_op(a, b):
    return a if a < b else b

dtype = np.dtype(np.int32)
max_val = np.iinfo(dtype).max
h_init = np.asarray(max_val, dtype=dtype)

offsets = cp.array([0, 7, 11, 16], dtype=np.int64)
first_segment = (8, 6, 7, 5, 3, 0, 9)
second_segment = (-4, 3, 0, 1)
third_segment = (3, 1, 11, 25, 8)
d_input = cp.array(
    [*first_segment, *second_segment, *third_segment],
    dtype=dtype,
)

start_o = offsets[:-1]
end_o = offsets[1:]

n_segments = start_o.size
d_output = cp.empty(n_segments, dtype=dtype)

# Instantiate reduction for the given operator and initial value
segmented_reduce = algorithms.segmented_reduce(
    d_input, d_output, start_o, end_o, min_op, h_init
)

# Determine temporary device storage requirements
temp_storage_size = segmented_reduce(
    None, d_input, d_output, n_segments, start_o, end_o, h_init
)

# Allocate temporary storage
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

# Run reduction
segmented_reduce(
    d_temp_storage, d_input, d_output, n_segments, start_o, end_o, h_init
)

# Check the result is correct
expected_output = cp.asarray([0, -4, 1], dtype=d_output.dtype)
assert (d_output == expected_output).all()
Parameters
  • d_in – Device array or iterator containing the input sequence of data items

  • d_out – Device array that will store the reduction result of each segment

  • start_offsets_in – Device array or iterator containing offsets to start of segments

  • end_offsets_in – Device array or iterator containing offsets to end of segments

  • op – Callable representing the binary operator to apply

  • h_init – NumPy array storing the initial value of the reduction

Returns

A callable object that can be used to perform the reduction
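
Because start_offsets_in and end_offsets_in accept iterators, uniform segments can be described without materializing offset arrays. A hedged sketch for fixed-length segments, built from the iterators documented below (assuming iterator-valued offsets behave like the offset arrays above):

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms
import cuda.parallel.experimental.iterators as iterators

def add_op(a, b):
    return a + b

def scale_op(i):
    return np.int64(4 * i)  # segment i starts at element 4 * i

n_segments = 3
d_input = cp.arange(4 * n_segments, dtype=np.int64)
d_output = cp.empty(n_segments, dtype=np.int64)
h_init = np.asarray(0, dtype=np.int64)

# segment i spans [4 * i, 4 * (i + 1))
start_o = iterators.TransformIterator(iterators.CountingIterator(np.int64(0)), scale_op)
end_o = iterators.TransformIterator(iterators.CountingIterator(np.int64(1)), scale_op)

segment_sum = algorithms.segmented_reduce(d_input, d_output, start_o, end_o, add_op, h_init)
temp_storage_size = segment_sum(None, d_input, d_output, n_segments, start_o, end_o, h_init)
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)
segment_sum(d_temp_storage, d_input, d_output, n_segments, start_o, end_o, h_init)

# sums of [0..3], [4..7], [8..11]
assert (d_output == cp.asarray([6, 22, 38], dtype=np.int64)).all()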

Iterators

cuda.parallel.experimental.iterators.CacheModifiedInputIterator(device_array, modifier, prefix='')

Random Access Cache Modified Iterator that wraps a native device pointer.

Similar to https://nvidia.github.io/cccl/cub/api/classcub_1_1CacheModifiedInputIterator.html

Currently the only supported modifier is “stream” (LOAD_CS).

Example

The code snippet below demonstrates the usage of a CacheModifiedInputIterator:

import functools

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms
import cuda.parallel.experimental.iterators as iterators

def add_op(a, b):
    return a + b

values = [8, 6, 7, 5, 3, 0, 9]
d_input = cp.array(values, dtype=np.int32)

iterator = iterators.CacheModifiedInputIterator(
    d_input, modifier="stream"
)  # Input sequence
h_init = np.array([0], dtype=np.int32)  # Initial value for the reduction
d_output = cp.empty(1, dtype=np.int32)  # Storage for output

# Instantiate reduction, determine storage requirements, and allocate storage
reduce_into = algorithms.reduce_into(iterator, d_output, add_op, h_init)
temp_storage_size = reduce_into(None, iterator, d_output, len(values), h_init)
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

# Run reduction
reduce_into(d_temp_storage, iterator, d_output, len(values), h_init)

expected_output = functools.reduce(lambda a, b: a + b, values)
assert (d_output == expected_output).all()
Parameters
  • device_array – CUDA device array storing the input sequence of data items

  • modifier – The PTX cache load modifier

  • prefix – An optional prefix added to the iterator’s methods to prevent name collisions.

Returns

A CacheModifiedInputIterator object initialized with device_array

cuda.parallel.experimental.iterators.ConstantIterator(value)

Returns an Iterator representing a sequence of constant values.

Similar to https://nvidia.github.io/cccl/thrust/api/classthrust_1_1constant__iterator.html

Example

The code snippet below demonstrates the usage of a ConstantIterator representing the sequence [10, 10, 10]:

import functools

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms
import cuda.parallel.experimental.iterators as iterators

def add_op(a, b):
    return a + b

value = 10
num_items = 3

constant_it = iterators.ConstantIterator(np.int32(value))  # Input sequence
h_init = np.array([0], dtype=np.int32)  # Initial value for the reduction
d_output = cp.empty(1, dtype=np.int32)  # Storage for output

# Instantiate reduction, determine storage requirements, and allocate storage
reduce_into = algorithms.reduce_into(constant_it, d_output, add_op, h_init)
temp_storage_size = reduce_into(None, constant_it, d_output, num_items, h_init)
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

# Run reduction
reduce_into(d_temp_storage, constant_it, d_output, num_items, h_init)

expected_output = functools.reduce(lambda a, b: a + b, [value] * num_items)
assert (d_output == expected_output).all()
Parameters

value – The value of every item in the sequence

Returns

A ConstantIterator object initialized to value

cuda.parallel.experimental.iterators.CountingIterator(offset)

Returns an Iterator representing a sequence of incrementing values.

Similar to https://nvidia.github.io/cccl/thrust/api/classthrust_1_1counting__iterator.html

Example

The code snippet below demonstrates the usage of a CountingIterator representing the sequence [10, 11, 12]:

import functools

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms
import cuda.parallel.experimental.iterators as iterators

def add_op(a, b):
    return a + b

first_item = 10
num_items = 3

first_it = iterators.CountingIterator(np.int32(first_item))  # Input sequence
h_init = np.array([0], dtype=np.int32)  # Initial value for the reduction
d_output = cp.empty(1, dtype=np.int32)  # Storage for output

# Instantiate reduction, determine storage requirements, and allocate storage
reduce_into = algorithms.reduce_into(first_it, d_output, add_op, h_init)
temp_storage_size = reduce_into(None, first_it, d_output, num_items, h_init)
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

# Run reduction
reduce_into(d_temp_storage, first_it, d_output, num_items, h_init)

expected_output = functools.reduce(
    lambda a, b: a + b, range(first_item, first_item + num_items)
)
assert (d_output == expected_output).all()
Parameters

offset – The initial value of the sequence

Returns

A CountingIterator object initialized to offset

cuda.parallel.experimental.iterators.TransformIterator(it, op)

Returns an Iterator representing a transformed sequence of values.

Similar to https://nvidia.github.io/cccl/thrust/api/classthrust_1_1transform__iterator.html

Example

The code snippet below demonstrates the usage of a TransformIterator composed with a CountingIterator, squaring each item of the sequence [10, 11, 12] before the sequence is reduced:

import functools

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms
import cuda.parallel.experimental.iterators as iterators

def add_op(a, b):
    return a + b

def square_op(a):
    return a**2

first_item = 10
num_items = 3

transform_it = iterators.TransformIterator(
    iterators.CountingIterator(np.int32(first_item)), square_op
)  # Input sequence
h_init = np.array([0], dtype=np.int32)  # Initial value for the reduction
d_output = cp.empty(1, dtype=np.int32)  # Storage for output

# Instantiate reduction, determine storage requirements, and allocate storage
reduce_into = algorithms.reduce_into(transform_it, d_output, add_op, h_init)
temp_storage_size = reduce_into(None, transform_it, d_output, num_items, h_init)
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

# Run reduction
reduce_into(d_temp_storage, transform_it, d_output, num_items, h_init)

expected_output = functools.reduce(
    lambda a, b: a + b, [a**2 for a in range(first_item, first_item + num_items)]
)
assert (d_output == expected_output).all()
Parameters
  • it – The iterator object to be transformed

  • op – The transform operation

Returns

A TransformIterator object to transform the items in it using op
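
A TransformIterator can also wrap a device array directly, as the gpu_struct example further below does. A brief sketch (names illustrative):

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms
import cuda.parallel.experimental.iterators as iterators

def add_op(a, b):
    return a + b

def negate_op(a):
    return -a

d_input = cp.array([1, 2, 3], dtype=np.int32)
negated_it = iterators.TransformIterator(d_input, negate_op)  # Input sequence

h_init = np.array([0], dtype=np.int32)
d_output = cp.empty(1, dtype=np.int32)

reduce_into = algorithms.reduce_into(negated_it, d_output, add_op, h_init)
temp_storage_size = reduce_into(None, negated_it, d_output, d_input.size, h_init)
d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)
reduce_into(d_temp_storage, negated_it, d_output, d_input.size, h_init)

assert d_output.get()[0] == -6  # -(1 + 2 + 3)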

Utilities

cuda.parallel.experimental.struct.gpu_struct(this: type) → Type[Any]

Defines the given class as being a GpuStruct.

A GpuStruct represents a value composed of one or more other values, and is defined as a class with annotated fields (similar to a dataclass). The type of each field must be a subclass of np.number, like np.int32 or np.float64.

Arrays of GpuStruct objects can be used as inputs to cuda.parallel algorithms.
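
For instance, a minimal definition sketch with two np.number-typed fields (the Complex64 name is illustrative):

import numpy as np

from cuda.parallel.experimental.struct import gpu_struct

@gpu_struct
class Complex64:
    real: np.float32
    imag: np.float32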

Example

The code snippet below shows how to use gpu_struct to define a MinMax type (composed of min_val and max_val fields), and perform a reduction on an input array of floating-point values to compute the smallest and the largest absolute values it contains:

import cupy as cp
import numpy as np

import cuda.parallel.experimental.algorithms as algorithms
import cuda.parallel.experimental.iterators as iterators
from cuda.parallel.experimental.struct import gpu_struct

@gpu_struct
class MinMax:
    min_val: np.float64
    max_val: np.float64

def minmax_op(v1: MinMax, v2: MinMax):
    c_min = min(v1.min_val, v2.min_val)
    c_max = max(v1.max_val, v2.max_val)
    return MinMax(c_min, c_max)

def transform_op(v):
    av = abs(v)
    return MinMax(av, av)

nelems = 4096

d_in = cp.random.randn(nelems)
# input values are transformed into MinMax structures on the fly,
# because the data-parallel reduction requires a commutative binary
# operation whose operands and result share the same type
tr_it = iterators.TransformIterator(d_in, transform_op)

d_out = cp.empty(tuple(), dtype=MinMax.dtype)

# initial value is set to the identity elements of the
# minimum and maximum operators
h_init = MinMax(np.inf, -np.inf)

# get algorithm object
minmax_reduce = algorithms.reduce_into(tr_it, d_out, minmax_op, h_init)

# allocate needed temporary storage
tmp_sz = minmax_reduce(None, tr_it, d_out, nelems, h_init)
tmp_storage = cp.empty(tmp_sz, dtype=cp.uint8)

# invoke the reduction algorithm
minmax_reduce(tmp_storage, tr_it, d_out, nelems, h_init)

# copy values computed on the device to the host
actual = d_out.get()

h = np.abs(d_in.get())
expected = np.asarray([(h.min(), h.max())], dtype=MinMax.dtype)

assert actual == expected