CUDA Parallel
Warning
Python exposure of parallel algorithms is in public beta. The API is subject to change without notice.
Algorithms
- cuda.parallel.experimental.algorithms.binary_transform(d_in1: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, d_in2: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, d_out: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, op: Callable)
Apply a transformation to the given pair of input sequences according to the binary operation op.
Example
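    import cupy as cp
    import numpy as np
    import cuda.parallel.experimental.algorithms as algorithms

    def op(a, b):
        return a + b

    d_in1 = cp.array([1, 2, 3, 4, 5], dtype=np.int32)
    d_in2 = cp.array([10, 20, 30, 40, 50], dtype=np.int32)
    d_out = cp.empty_like(d_in1)

    # Instantiate the transform for the given inputs, output, and operator
    transform = algorithms.binary_transform(d_in1, d_in2, d_out, op)
    # Apply op to the first len(d_in1) pairs of items
    transform(d_in1, d_in2, d_out, len(d_in1))

    # Check the result against the same computation on the host
    expected = d_in1.get() + d_in2.get()
    np.testing.assert_allclose(d_out.get(), expected, rtol=1e-5)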
- Parameters
d_in1 – Device array or iterator containing the first input sequence of data items.
d_in2 – Device array or iterator containing the second input sequence of data items.
d_out – Device array or iterator to store the result of the transformation.
op – Binary operation to apply to each pair of items from the input sequences.
- Returns
A callable that performs the transformation.
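As a mental model of what binary_transform computes, the plain-Python sketch below applies op to each pair (d_in1[i], d_in2[i]) and writes the result to position i of the output. It is a host-side reference only, assuming both inputs hold at least num_items elements; the helper name binary_transform_reference is hypothetical and not part of cuda.parallel.

```python
# Hypothetical host-side reference for binary_transform (not part of cuda.parallel):
# out[i] = op(in1[i], in2[i]) for every i in [0, num_items).
def binary_transform_reference(in1, in2, out, num_items, op):
    for i in range(num_items):
        out[i] = op(in1[i], in2[i])
    return out


def add(a, b):
    return a + b


result = binary_transform_reference([1, 2, 3], [10, 20, 30], [0] * 3, 3, add)
print(result)  # [11, 22, 33]
```

Because each output position depends only on the inputs at the same index, every element can be computed independently, which is what makes the operation a good fit for a GPU.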
- cuda.parallel.experimental.algorithms.exclusive_scan(d_in: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, d_out: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, op: Callable, h_init: numpy.ndarray)
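Computes a device-wide exclusive scan using the specified binary operator op and initial value h_init: the i-th output combines h_init with the input items before position i, so the i-th input item is excluded.
Example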
Below, exclusive_scan is used to compute an exclusive scan of a sequence of integers.

    import cupy as cp
    import numpy as np
    import cuda.parallel.experimental.algorithms as algorithms

    def max_op(a, b):
        return max(a, b)

    h_init = np.array([1], dtype="int32")
    d_input = cp.array([-5, 0, 2, -3, 2, 4, 0, -1, 2, 8], dtype="int32")
    d_output = cp.empty_like(d_input, dtype="int32")

    # Instantiate scan for the given operator and initial value
    scanner = algorithms.exclusive_scan(d_input, d_output, max_op, h_init)

    # Determine temporary device storage requirements
    temp_storage_size = scanner(None, d_input, d_output, d_input.size, h_init)

    # Allocate temporary storage
    d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

    # Run scan
    scanner(d_temp_storage, d_input, d_output, d_input.size, h_init)

    # Check the result is correct
    expected = np.asarray([1, 1, 1, 2, 2, 2, 4, 4, 4, 4])
    np.testing.assert_equal(d_output.get(), expected)
- Parameters
d_in – Device array or iterator containing the input sequence of data items
d_out – Device array that will store the result of the scan
op – Callable representing the binary operator to apply
h_init – NumPy array storing the initial value of the scan
- Returns
A callable object that can be used to perform the scan
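The exclusive-scan rule can be checked on the host with a short sequential loop. This is a minimal reference sketch, not the device implementation; the helper name exclusive_scan_reference is hypothetical. With max and an initial value of 1 it reproduces the expected output of the example above.

```python
# Hypothetical host-side reference for exclusive_scan (not part of cuda.parallel):
# out[0] = init and out[i] = op(out[i - 1], in[i - 1]), so in[i] never
# contributes to out[i].
def exclusive_scan_reference(values, op, init):
    out = []
    acc = init
    for v in values:
        out.append(acc)  # emit the running value *before* folding in v
        acc = op(acc, v)
    return out


values = [-5, 0, 2, -3, 2, 4, 0, -1, 2, 8]
print(exclusive_scan_reference(values, max, 1))
# [1, 1, 1, 2, 2, 2, 4, 4, 4, 4]
```

Note that the last input item never appears in any output of an exclusive scan; it would only matter for a following position.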
- cuda.parallel.experimental.algorithms.inclusive_scan(d_in: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, d_out: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, op: Callable, h_init: numpy.ndarray)
Computes a device-wide inclusive scan using the specified binary operator op and initial value h_init: unlike an exclusive scan, the i-th output also includes the i-th input item.
Example
Below, inclusive_scan is used to compute an inclusive scan of a sequence of integers.

    import cupy as cp
    import numpy as np
    import cuda.parallel.experimental.algorithms as algorithms

    def add_op(a, b):
        return a + b

    h_init = np.array([0], dtype="int32")
    d_input = cp.array([-5, 0, 2, -3, 2, 4, 0, -1, 2, 8], dtype="int32")
    d_output = cp.empty_like(d_input, dtype="int32")

    # Instantiate scan for the given operator and initial value
    scanner = algorithms.inclusive_scan(d_input, d_output, add_op, h_init)

    # Determine temporary device storage requirements
    temp_storage_size = scanner(None, d_input, d_output, d_input.size, h_init)

    # Allocate temporary storage
    d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

    # Run scan
    scanner(d_temp_storage, d_input, d_output, d_input.size, h_init)

    # Check the result is correct
    expected = np.asarray([-5, -5, -3, -6, -4, 0, 0, -1, 1, 9])
    np.testing.assert_equal(d_output.get(), expected)
- Parameters
d_in – Device array or iterator containing the input sequence of data items
d_out – Device array that will store the result of the scan
op – Callable representing the binary operator to apply
h_init – NumPy array storing the initial value of the scan
- Returns
A callable object that can be used to perform the scan
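A host-side reference for the inclusive scan can be written with itertools.accumulate. This sketch assumes h_init is folded in ahead of the first input item; with add and an initial value of 0 (as in the example above) that assumption does not change the result. The helper name inclusive_scan_reference is hypothetical.

```python
from itertools import accumulate
from operator import add


# Hypothetical host-side reference for inclusive_scan (not part of cuda.parallel),
# assuming out[i] = op(...op(op(init, in[0]), in[1])..., in[i]).
def inclusive_scan_reference(values, op, init):
    # accumulate(..., initial=init) yields init first; drop it so the
    # output has the same length as the input.
    return list(accumulate(values, op, initial=init))[1:]


values = [-5, 0, 2, -3, 2, 4, 0, -1, 2, 8]
print(inclusive_scan_reference(values, add, 0))
# [-5, -5, -3, -6, -4, 0, 0, -1, 1, 9]
```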
- cuda.parallel.experimental.algorithms.merge_sort(d_in_keys: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, d_in_items: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase | None, d_out_keys: cuda.parallel.experimental.typing.DeviceArrayLike, d_out_items: cuda.parallel.experimental.typing.DeviceArrayLike | None, op: Callable)
Implements a device-wide merge sort using d_in_keys and the comparison operator op.
Example
Below, merge_sort is used to sort a sequence of keys in place. It also rearranges the items according to the keys' order.

    import cupy as cp
    import numpy as np
    import cuda.parallel.experimental.algorithms as algorithms

    def compare_op(lhs, rhs):
        return np.uint8(lhs < rhs)

    h_in_keys = np.array([-5, 0, 2, -3, 2, 4, 0, -1, 2, 8], dtype="int32")
    h_in_items = np.array(
        [-3.2, 2.2, 1.9, 4.0, -3.9, 2.7, 0, 7.3, 2.9, 5.4], dtype="float32"
    )

    d_in_keys = cp.asarray(h_in_keys)
    d_in_items = cp.asarray(h_in_items)

    # Instantiate merge_sort for the given keys, items, and operator
    merge_sort = algorithms.merge_sort(
        d_in_keys, d_in_items, d_in_keys, d_in_items, compare_op
    )

    # Determine temporary device storage requirements
    temp_storage_size = merge_sort(
        None, d_in_keys, d_in_items, d_in_keys, d_in_items, d_in_keys.size
    )

    # Allocate temporary storage
    d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

    # Run merge_sort (the input arrays double as outputs, so the sort is in place)
    merge_sort(
        d_temp_storage, d_in_keys, d_in_items, d_in_keys, d_in_items, d_in_keys.size
    )

    # Check the result is correct
    h_out_keys = cp.asnumpy(d_in_keys)
    h_out_items = cp.asnumpy(d_in_items)

    # kind="stable" works on all NumPy versions (the stable= keyword needs NumPy 2.0)
    argsort = np.argsort(h_in_keys, kind="stable")
    h_in_keys = h_in_keys[argsort]
    h_in_items = h_in_items[argsort]

    np.testing.assert_array_equal(h_out_keys, h_in_keys)
    np.testing.assert_array_equal(h_out_items, h_in_items)
- Parameters
d_in_keys – Device array or iterator containing the input keys to be sorted
d_in_items – Optional device array or iterator that contains each key’s corresponding item
d_out_keys – Device array to store the sorted keys
d_out_items – Optional device array to store the sorted items
op – Callable representing the comparison operator
- Returns
A callable object that can be used to perform the merge sort
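The role of op as a "less than" comparator can be illustrated on the host with functools.cmp_to_key. This is a minimal sketch, not the device algorithm; merge_sort_reference is a hypothetical name. Python's sort is stable, matching the stable argsort the example above checks against, so equal keys keep their input order; here each item simply records its key's original position.

```python
from functools import cmp_to_key


# Hypothetical host-side reference for merge_sort (not part of cuda.parallel).
# `less` plays the role of op: it returns a truthy value when lhs should come
# before rhs.
def merge_sort_reference(keys, items, less):
    def cmp(a, b):
        if less(a[0], b[0]):
            return -1
        if less(b[0], a[0]):
            return 1
        return 0  # neither precedes the other: treat the keys as equal

    pairs = sorted(zip(keys, items), key=cmp_to_key(cmp))
    return [k for k, _ in pairs], [v for _, v in pairs]


keys = [-5, 0, 2, -3, 2, 4, 0, -1, 2, 8]
items = list(range(len(keys)))  # item i records the original position of key i
sorted_keys, sorted_items = merge_sort_reference(keys, items, lambda l, r: l < r)
print(sorted_keys)   # [-5, -3, -1, 0, 0, 2, 2, 2, 4, 8]
print(sorted_items)  # [0, 3, 7, 1, 6, 2, 4, 8, 5, 9]
```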
- cuda.parallel.experimental.algorithms.reduce_into(d_in: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, d_out: cuda.parallel.experimental.typing.DeviceArrayLike, op: Callable, h_init: numpy.ndarray)
Computes a device-wide reduction using the specified binary operator op and initial value h_init.
Example
Below, reduce_into is used to compute the minimum value of a sequence of integers.

    import cupy as cp
    import numpy as np
    import cuda.parallel.experimental.algorithms as algorithms

    def min_op(a, b):
        return a if a < b else b

    dtype = np.int32
    h_init = np.array([42], dtype=dtype)
    d_input = cp.array([8, 6, 7, 5, 3, 0, 9], dtype=dtype)
    d_output = cp.empty(1, dtype=dtype)

    # Instantiate reduction for the given operator and initial value
    reduce_into = algorithms.reduce_into(d_input, d_output, min_op, h_init)

    # Determine temporary device storage requirements
    temp_storage_size = reduce_into(None, d_input, d_output, len(d_input), h_init)

    # Allocate temporary storage
    d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

    # Run reduction
    reduce_into(d_temp_storage, d_input, d_output, len(d_input), h_init)

    # Check the result is correct
    expected_output = 0
    assert (d_output == expected_output).all()
- Parameters
d_in – Device array or iterator containing the input sequence of data items
d_out – Device array (of size 1) that will store the result of the reduction
op – Callable representing the binary operator to apply
h_init – NumPy array storing the initial value of the reduction
- Returns
A callable object that can be used to perform the reduction
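On the host, the same result is a left fold of op over the input starting from h_init. Because a parallel reduction may group elements in any order, op should be associative for this sequential fold to be a valid reference (floating-point sums may still differ in the last bits). The helper name reduce_reference is hypothetical.

```python
from functools import reduce


# Hypothetical host-side reference for reduce_into (not part of cuda.parallel):
# fold op over the input, starting from init; the result is a single value.
def reduce_reference(values, op, init):
    return reduce(op, values, init)


def min_op(a, b):
    return a if a < b else b


print(reduce_reference([8, 6, 7, 5, 3, 0, 9], min_op, 42))  # 0
print(reduce_reference([], min_op, 42))  # 42: in this sketch, empty input yields init
```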
- cuda.parallel.experimental.algorithms.segmented_reduce(d_in: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, d_out: cuda.parallel.experimental.typing.DeviceArrayLike, start_offsets_in: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, end_offsets_in: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, op: Callable, h_init: numpy.ndarray)
Computes a device-wide segmented reduction using the specified binary operator op and initial value h_init. Each segment is reduced independently, producing one output value per segment.
Example
Below, segmented_reduce is used to compute the minimum value of each segment of a sequence of integers.

    import cupy as cp
    import numpy as np
    import cuda.parallel.experimental.algorithms as algorithms

    def min_op(a, b):
        return a if a < b else b

    dtype = np.dtype(np.int32)
    max_val = np.iinfo(dtype).max
    h_init = np.asarray(max_val, dtype=dtype)

    offsets = cp.array([0, 7, 11, 16], dtype=np.int64)
    first_segment = (8, 6, 7, 5, 3, 0, 9)
    second_segment = (-4, 3, 0, 1)
    third_segment = (3, 1, 11, 25, 8)
    d_input = cp.array(
        [*first_segment, *second_segment, *third_segment],
        dtype=dtype,
    )

    # Segment s spans d_input[start_o[s]:end_o[s]]
    start_o = offsets[:-1]
    end_o = offsets[1:]
    n_segments = start_o.size
    d_output = cp.empty(n_segments, dtype=dtype)

    # Instantiate reduction for the given operator and initial value
    segmented_reduce = algorithms.segmented_reduce(
        d_input, d_output, start_o, end_o, min_op, h_init
    )

    # Determine temporary device storage requirements
    temp_storage_size = segmented_reduce(
        None, d_input, d_output, n_segments, start_o, end_o, h_init
    )

    # Allocate temporary storage
    d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

    # Run reduction
    segmented_reduce(
        d_temp_storage, d_input, d_output, n_segments, start_o, end_o, h_init
    )

    # Check the result is correct
    expected_output = cp.asarray([0, -4, 1], dtype=d_output.dtype)
    assert (d_output == expected_output).all()
- Parameters
d_in – Device array or iterator containing the input sequence of data items
d_out – Device array that will store the result of the reduction
start_offsets_in – Device array or iterator containing the offset of the first item of each segment
end_offsets_in – Device array or iterator containing the offset one past the last item of each segment
op – Callable representing the binary operator to apply
h_init – NumPy array storing the initial value of the reduction
- Returns
A callable object that can be used to perform the reduction
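The segment layout used in the example above (end offsets are exclusive, so segment s is values[start[s]:end[s]]) can be mirrored on the host with slicing and a fold per segment. This is a minimal reference sketch; segmented_reduce_reference and INT32_MAX are hypothetical names introduced here.

```python
from functools import reduce


# Hypothetical host-side reference for segmented_reduce (not part of cuda.parallel):
# segment s covers values[start_offsets[s]:end_offsets[s]] (end is exclusive),
# and out[s] is that segment folded with op starting from init.
def segmented_reduce_reference(values, start_offsets, end_offsets, op, init):
    return [
        reduce(op, values[begin:end], init)
        for begin, end in zip(start_offsets, end_offsets)
    ]


values = [8, 6, 7, 5, 3, 0, 9, -4, 3, 0, 1, 3, 1, 11, 25, 8]
offsets = [0, 7, 11, 16]
INT32_MAX = 2**31 - 1  # same role as np.iinfo(np.int32).max in the example
print(segmented_reduce_reference(values, offsets[:-1], offsets[1:], min, INT32_MAX))
# [0, -4, 1]
```

In this sketch an empty segment (start equal to end) simply yields the initial value, which is why an identity element such as the largest int32 is a natural choice of h_init for a minimum.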
- cuda.parallel.experimental.algorithms.unary_transform(d_in: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, d_out: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, op: Callable)
Apply a transformation to each element of the input according to the unary operation op.
Example
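    import cupy as cp
    import numpy as np
    import cuda.parallel.experimental.algorithms as algorithms

    def op(a):
        return a + 1

    d_in = cp.array([1, 2, 3, 4, 5], dtype=np.int32)
    d_out = cp.empty_like(d_in)

    # Instantiate the transform for the given input, output, and operator
    transform = algorithms.unary_transform(d_in, d_out, op)
    # Apply op to the first len(d_in) items
    transform(d_in, d_out, len(d_in))

    # Check the result against the same computation on the host
    expected = d_in.get() + 1
    np.testing.assert_allclose(d_out.get(), expected, rtol=1e-5)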
- Parameters
d_in – Device array or iterator containing the input sequence of data items.
d_out – Device array or iterator to store the result of the transformation.
op – Unary operation to apply to each element of the input.
- Returns
A callable that performs the transformation.
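For reference, the element-wise rule behind unary_transform is simply out[i] = op(in[i]). The sketch below is a host-side illustration only; unary_transform_reference is a hypothetical name, not part of cuda.parallel.

```python
# Hypothetical host-side reference for unary_transform (not part of cuda.parallel):
# out[i] = op(in[i]); each output depends only on the input at the same index.
def unary_transform_reference(values, op):
    return [op(v) for v in values]


def op(a):
    return a + 1


print(unary_transform_reference([1, 2, 3, 4, 5], op))  # [2, 3, 4, 5, 6]
```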
- cuda.parallel.experimental.algorithms.unique_by_key(d_in_keys: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, d_in_items: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, d_out_keys: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, d_out_items: cuda.parallel.experimental.typing.DeviceArrayLike | cuda.parallel.experimental.iterators._iterators.IteratorBase, d_out_num_selected: cuda.parallel.experimental.typing.DeviceArrayLike, op: Callable)
Implements a device-wide unique-by-key operation using d_in_keys and the comparison operator op. Only the first key and its item from each run of equal keys are selected, and the total number of selected items is also reported.
Example
Below, unique_by_key is used to populate the arrays of output keys and items with the first key and its corresponding item from each run of equal keys. It also outputs the number of items selected.

    import cupy as cp
    import numpy as np
    import cuda.parallel.experimental.algorithms as algorithms

    def compare_op(lhs, rhs):
        return np.uint8(lhs == rhs)

    h_in_keys = np.array([0, 2, 2, 9, 5, 5, 5, 8], dtype="int32")
    h_in_items = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype="float32")

    d_in_keys = cp.asarray(h_in_keys)
    d_in_items = cp.asarray(h_in_items)
    d_out_keys = cp.empty_like(d_in_keys)
    d_out_items = cp.empty_like(d_in_items)
    d_out_num_selected = cp.empty(1, np.int32)

    # Instantiate unique_by_key for the given keys, items, num items selected, and operator
    unique_by_key = algorithms.unique_by_key(
        d_in_keys, d_in_items, d_out_keys, d_out_items, d_out_num_selected, compare_op
    )

    # Determine temporary device storage requirements
    temp_storage_size = unique_by_key(
        None, d_in_keys, d_in_items, d_out_keys, d_out_items,
        d_out_num_selected, d_in_keys.size,
    )

    # Allocate temporary storage
    d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

    # Run unique_by_key
    unique_by_key(
        d_temp_storage, d_in_keys, d_in_items, d_out_keys, d_out_items,
        d_out_num_selected, d_in_keys.size,
    )

    # Check the result is correct
    num_selected = cp.asnumpy(d_out_num_selected)[0]
    h_out_keys = cp.asnumpy(d_out_keys)[:num_selected]
    h_out_items = cp.asnumpy(d_out_items)[:num_selected]

    # Build the expected output on the host: keep the first key/item of each run
    expected_keys = [h_in_keys[0]]
    expected_items = [h_in_items[0]]
    for idx, (prev_key, curr_key) in enumerate(zip(h_in_keys, h_in_keys[1:])):
        if prev_key != curr_key:
            expected_keys.append(curr_key)
            # add 1 since we are enumerating over pairs
            expected_items.append(h_in_items[idx + 1])

    np.testing.assert_array_equal(h_out_keys, np.array(expected_keys))
    np.testing.assert_array_equal(h_out_items, np.array(expected_items))
- Parameters
d_in_keys – Device array or iterator containing the input sequence of keys
d_in_items – Device array or iterator that contains each key’s corresponding item
d_out_keys – Device array or iterator to store the selected keys
d_out_items – Device array or iterator to store the item corresponding to each selected key
d_out_num_selected – Device array to store how many items were selected
op – Callable representing the equality operator
- Returns
A callable object that can be used to perform unique by key
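The run-based selection rule can be sketched on the host as follows. Because only consecutive equal keys form a run (as the expected-output loop in the example above shows), a key that reappears later in unsorted input is selected again. This is a minimal reference sketch; unique_by_key_reference is a hypothetical name.

```python
# Hypothetical host-side reference for unique_by_key (not part of cuda.parallel).
# `equal` plays the role of op and is applied to neighbouring keys.
def unique_by_key_reference(keys, items, equal):
    out_keys, out_items = [], []
    for i, (k, v) in enumerate(zip(keys, items)):
        # Keep the first element of each run of equal keys.
        if i == 0 or not equal(keys[i - 1], k):
            out_keys.append(k)
            out_items.append(v)
    return out_keys, out_items, len(out_keys)


keys = [0, 2, 2, 9, 5, 5, 5, 8, 2]
items = [1, 2, 3, 4, 5, 6, 7, 8, 9]
print(unique_by_key_reference(keys, items, lambda l, r: l == r))
# ([0, 2, 9, 5, 8, 2], [1, 2, 4, 5, 8, 9], 6)
```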
Iterators
- cuda.parallel.experimental.iterators.CacheModifiedInputIterator(device_array, modifier)
Random Access Cache Modified Iterator that wraps a native device pointer.
Similar to https://nvidia.github.io/cccl/cub/api/classcub_1_1CacheModifiedInputIterator.html
Currently the only supported modifier is “stream” (LOAD_CS).
Example
The code snippet below demonstrates the usage of a CacheModifiedInputIterator:

    import functools
    import cupy as cp
    import numpy as np
    import cuda.parallel.experimental.algorithms as algorithms
    import cuda.parallel.experimental.iterators as iterators

    def add_op(a, b):
        return a + b

    values = [8, 6, 7, 5, 3, 0, 9]
    d_input = cp.array(values, dtype=np.int32)

    iterator = iterators.CacheModifiedInputIterator(
        d_input, modifier="stream"
    )  # Input sequence
    h_init = np.array([0], dtype=np.int32)  # Initial value for the reduction
    d_output = cp.empty(1, dtype=np.int32)  # Storage for output

    # Instantiate reduction, determine storage requirements, and allocate storage
    reduce_into = algorithms.reduce_into(iterator, d_output, add_op, h_init)
    temp_storage_size = reduce_into(None, iterator, d_output, len(values), h_init)
    d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

    # Run reduction
    reduce_into(d_temp_storage, iterator, d_output, len(values), h_init)

    expected_output = functools.reduce(lambda a, b: a + b, values)
    assert (d_output == expected_output).all()
- Parameters
device_array – CUDA device array storing the input sequence of data items
modifier – The PTX cache load modifier
prefix – An optional prefix added to the iterator’s methods to prevent name collisions.
- Returns
A CacheModifiedInputIterator object initialized with device_array
- cuda.parallel.experimental.iterators.ConstantIterator(value)
Returns an Iterator representing a sequence of constant values.
Similar to https://nvidia.github.io/cccl/thrust/api/classthrust_1_1constant__iterator.html
Example
The code snippet below demonstrates the usage of a ConstantIterator representing the sequence [10, 10, 10]:

    import functools
    import cupy as cp
    import numpy as np
    import cuda.parallel.experimental.algorithms as algorithms
    import cuda.parallel.experimental.iterators as iterators

    def add_op(a, b):
        return a + b

    value = 10
    num_items = 3

    constant_it = iterators.ConstantIterator(np.int32(value))  # Input sequence
    h_init = np.array([0], dtype=np.int32)  # Initial value for the reduction
    d_output = cp.empty(1, dtype=np.int32)  # Storage for output

    # Instantiate reduction, determine storage requirements, and allocate storage
    reduce_into = algorithms.reduce_into(constant_it, d_output, add_op, h_init)
    temp_storage_size = reduce_into(None, constant_it, d_output, num_items, h_init)
    d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

    # Run reduction
    reduce_into(d_temp_storage, constant_it, d_output, num_items, h_init)

    expected_output = functools.reduce(lambda a, b: a + b, [value] * num_items)
    assert (d_output == expected_output).all()
- Parameters
value – The value of every item in the sequence
- Returns
A ConstantIterator object initialized to value
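Conceptually, a constant iterator is a random-access sequence whose element i is the same value for every i, so no memory has to be allocated for it. The host-side model below illustrates that idea and the resulting arithmetic for a sum: init + num_items * value. HostConstantIterator is a hypothetical stand-in, not the cuda.parallel class.

```python
from functools import reduce
from operator import add


# Hypothetical host-side model of ConstantIterator (not the cuda.parallel class):
# element i is `value` for every i, so nothing needs to be stored.
class HostConstantIterator:
    def __init__(self, value):
        self.value = value

    def __getitem__(self, i):
        return self.value


it = HostConstantIterator(10)
num_items = 3
total = reduce(add, (it[i] for i in range(num_items)), 0)
print(total)  # 30, i.e. init + num_items * value = 0 + 3 * 10
```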
- cuda.parallel.experimental.iterators.CountingIterator(offset)
Returns an Iterator representing a sequence of incrementing values.
Similar to https://nvidia.github.io/cccl/thrust/api/classthrust_1_1counting__iterator.html
Example
The code snippet below demonstrates the usage of a CountingIterator representing the sequence [10, 11, 12]:

    import functools
    import cupy as cp
    import numpy as np
    import cuda.parallel.experimental.algorithms as algorithms
    import cuda.parallel.experimental.iterators as iterators

    def add_op(a, b):
        return a + b

    first_item = 10
    num_items = 3

    first_it = iterators.CountingIterator(np.int32(first_item))  # Input sequence
    h_init = np.array([0], dtype=np.int32)  # Initial value for the reduction
    d_output = cp.empty(1, dtype=np.int32)  # Storage for output

    # Instantiate reduction, determine storage requirements, and allocate storage
    reduce_into = algorithms.reduce_into(first_it, d_output, add_op, h_init)
    temp_storage_size = reduce_into(None, first_it, d_output, num_items, h_init)
    d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

    # Run reduction
    reduce_into(d_temp_storage, first_it, d_output, num_items, h_init)

    expected_output = functools.reduce(
        lambda a, b: a + b, range(first_item, first_item + num_items)
    )
    assert (d_output == expected_output).all()
- Parameters
offset – The initial value of the sequence
- Returns
A CountingIterator object initialized to offset
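A counting iterator computes element i as offset + i on demand. The host-side model below shows this, and checks the sum against the closed form n * offset + n(n - 1)/2 (for the example above: 3 * 10 + 3 = 33). HostCountingIterator is a hypothetical stand-in, not the cuda.parallel class.

```python
# Hypothetical host-side model of CountingIterator (not the cuda.parallel class):
# element i is computed as offset + i, so nothing needs to be stored.
class HostCountingIterator:
    def __init__(self, offset):
        self.offset = offset

    def __getitem__(self, i):
        return self.offset + i


it = HostCountingIterator(10)
num_items = 3
print([it[i] for i in range(num_items)])  # [10, 11, 12]

total = sum(it[i] for i in range(num_items))
closed_form = num_items * it.offset + num_items * (num_items - 1) // 2
print(total, closed_form)  # 33 33
```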
- cuda.parallel.experimental.iterators.ReverseIterator(sequence)
Returns an Iterator over an array in reverse.
Similar to std::reverse_iterator (https://en.cppreference.com/w/cpp/iterator/reverse_iterator)
Example
The code snippet below demonstrates the usage of a ReverseIterator:

    import cupy as cp
    import numpy as np
    import cuda.parallel.experimental.algorithms as algorithms
    import cuda.parallel.experimental.iterators as iterators

    def add_op(a, b):
        return a + b

    h_init = np.array([0], dtype="int32")
    d_input = cp.array([-5, 0, 2, -3, 2, 4, 0, -1, 2, 8], dtype="int32")
    d_output = cp.empty_like(d_input, dtype="int32")
    reverse_it = iterators.ReverseIterator(d_input)

    # Instantiate scan, determine storage requirements, and allocate storage
    inclusive_scan = algorithms.inclusive_scan(reverse_it, d_output, add_op, h_init)
    temp_storage_size = inclusive_scan(None, reverse_it, d_output, len(d_input), h_init)
    d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

    # Run scan
    inclusive_scan(d_temp_storage, reverse_it, d_output, len(d_input), h_init)

    # Check the result is correct
    expected = np.asarray([8, 10, 9, 9, 13, 15, 12, 14, 14, 9])
    np.testing.assert_equal(d_output.get(), expected)
- Parameters
sequence – The iterator or CUDA device array to be reversed
- Returns
A ReverseIterator object initialized with sequence
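The index mapping behind a reverse iterator is element i = sequence[n - 1 - i]. The host-side model below applies that mapping to the input of the example above and then runs an inclusive sum, reproducing its expected output. HostReverseIterator is a hypothetical stand-in, not the cuda.parallel class.

```python
from itertools import accumulate
from operator import add


# Hypothetical host-side model of ReverseIterator (not the cuda.parallel class):
# element i reads sequence[n - 1 - i], so no reversed copy is made.
class HostReverseIterator:
    def __init__(self, sequence):
        self.sequence = sequence

    def __getitem__(self, i):
        return self.sequence[len(self.sequence) - 1 - i]


data = [-5, 0, 2, -3, 2, 4, 0, -1, 2, 8]
rev = HostReverseIterator(data)
reversed_view = [rev[i] for i in range(len(data))]
print(reversed_view)  # [8, 2, -1, 0, 4, 2, -3, 2, 0, -5]
print(list(accumulate(reversed_view, add)))
# [8, 10, 9, 9, 13, 15, 12, 14, 14, 9]
```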
- cuda.parallel.experimental.iterators.TransformIterator(it, op)
Returns an Iterator representing a transformed sequence of values.
Similar to https://nvidia.github.io/cccl/thrust/api/classthrust_1_1transform__iterator.html
Example
The code snippet below demonstrates the usage of a TransformIterator composed with a CountingIterator, squaring each item of the sequence [10, 11, 12] and then reducing the squared values:

    import functools
    import cupy as cp
    import numpy as np
    import cuda.parallel.experimental.algorithms as algorithms
    import cuda.parallel.experimental.iterators as iterators

    def add_op(a, b):
        return a + b

    def square_op(a):
        return a**2

    first_item = 10
    num_items = 3

    transform_it = iterators.TransformIterator(
        iterators.CountingIterator(np.int32(first_item)), square_op
    )  # Input sequence
    h_init = np.array([0], dtype=np.int32)  # Initial value for the reduction
    d_output = cp.empty(1, dtype=np.int32)  # Storage for output

    # Instantiate reduction, determine storage requirements, and allocate storage
    reduce_into = algorithms.reduce_into(transform_it, d_output, add_op, h_init)
    temp_storage_size = reduce_into(None, transform_it, d_output, num_items, h_init)
    d_temp_storage = cp.empty(temp_storage_size, dtype=np.uint8)

    # Run reduction
    reduce_into(d_temp_storage, transform_it, d_output, num_items, h_init)

    expected_output = functools.reduce(
        lambda a, b: a + b, [a**2 for a in range(first_item, first_item + num_items)]
    )
    assert (d_output == expected_output).all()
- Parameters
it – The iterator object to be transformed
op – The transform operation
- Returns
A TransformIterator object that transforms the items in it using op
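A transform iterator applies op lazily when an element is read: element i = op(it[i]). The host-side model below composes it with a counting model, mirroring the example above (100 + 121 + 144 = 365). HostCountingIterator and HostTransformIterator are hypothetical stand-ins, not the cuda.parallel classes.

```python
# Hypothetical host-side models (not the cuda.parallel classes).
class HostCountingIterator:
    def __init__(self, offset):
        self.offset = offset

    def __getitem__(self, i):
        return self.offset + i


class HostTransformIterator:
    def __init__(self, it, op):
        self.it = it
        self.op = op

    def __getitem__(self, i):
        # op runs only when element i is read; no transformed array is stored
        return self.op(self.it[i])


squares = HostTransformIterator(HostCountingIterator(10), lambda a: a**2)
print([squares[i] for i in range(3)])  # [100, 121, 144]
total = sum(squares[i] for i in range(3))
print(total)  # 365
```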
Utilities
cuda.parallel.experimental.struct.gpu_struct(this: type) → Type[Any]
Defines the given class as being a GpuStruct.
A GpuStruct represents a value composed of one or more other values, and is defined as a class with annotated fields (similar to a dataclass). The type of each field must be a subclass of np.number, like np.int32 or np.float64.
Arrays of GpuStruct objects can be used as inputs to cuda.parallel algorithms.
Example
The code snippet below shows how to use gpu_struct to define a MinMax type (composed of min_val and max_val values) and perform a reduction on an input array of floating-point values to compute the smallest and largest absolute values:
    import cupy as cp
    import numpy as np
    import cuda.parallel.experimental.algorithms as algorithms
    import cuda.parallel.experimental.iterators as iterators
    from cuda.parallel.experimental.struct import gpu_struct

    @gpu_struct
    class MinMax:
        min_val: np.float64
        max_val: np.float64

    def minmax_op(v1: MinMax, v2: MinMax):
        c_min = min(v1.min_val, v2.min_val)
        c_max = max(v1.max_val, v2.max_val)
        return MinMax(c_min, c_max)

    def transform_op(v):
        av = abs(v)
        return MinMax(av, av)

    nelems = 4096
    d_in = cp.random.randn(nelems)

    # Input values are transformed into MinMax structures on the fly by the
    # iterator (no intermediate array), because the reduction algorithm
    # requires a commutative binary operation whose operands have the same type.
    tr_it = iterators.TransformIterator(d_in, transform_op)

    d_out = cp.empty(tuple(), dtype=MinMax.dtype)

    # Initial value set to the identity elements of the minimum and maximum operators
    h_init = MinMax(np.inf, -np.inf)

    # Get the algorithm object
    minmax_reduce = algorithms.reduce_into(tr_it, d_out, minmax_op, h_init)

    # Allocate the needed temporary storage
    tmp_sz = minmax_reduce(None, tr_it, d_out, nelems, h_init)
    tmp_storage = cp.empty(tmp_sz, dtype=cp.uint8)

    # Invoke the reduction algorithm
    minmax_reduce(tmp_storage, tr_it, d_out, nelems, h_init)

    # Compare the values computed on the device with a host computation
    actual = d_out.get()
    h = np.abs(d_in.get())
    expected = np.asarray([(h.min(), h.max())], dtype=MinMax.dtype)

    assert actual == expected
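The logic of the MinMax reduction can be checked on the host with an ordinary frozen dataclass in place of gpu_struct. This sketch (HostMinMax is a hypothetical stand-in, and the four input values are arbitrary) shows why (inf, -inf) is the right initial value: it is neutral for minmax_op, so it never changes the result.

```python
import math
from dataclasses import dataclass
from functools import reduce


# Hypothetical host-side stand-in for the gpu_struct MinMax type.
@dataclass(frozen=True)
class HostMinMax:
    min_val: float
    max_val: float


def minmax_op(v1, v2):
    return HostMinMax(min(v1.min_val, v2.min_val), max(v1.max_val, v2.max_val))


def transform_op(v):
    av = abs(v)
    return HostMinMax(av, av)


identity = HostMinMax(math.inf, -math.inf)
values = [-2.5, 0.75, 1.5, -0.25]
result = reduce(minmax_op, map(transform_op, values), identity)
print(result)  # HostMinMax(min_val=0.25, max_val=2.5)

# The identity really is neutral for minmax_op:
assert minmax_op(identity, result) == result
```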