CUDA Kernel API#
Kernel declaration#
The @cuda.jit decorator is used to create a CUDA dispatcher object that can
be configured and launched:
- numba.cuda.jit(
 - func_or_sig=None,
 - device=False,
 - inline='never',
 - forceinline=False,
 - link=[],
 - debug=None,
 - opt=None,
 - lineinfo=False,
 - cache=False,
 - launch_bounds=None,
 - lto=None,
 - **kws,
 - )
 JIT compile a Python function for CUDA GPUs.
- Parameters:
 func_or_sig –
A function to JIT compile, or signatures of a function to compile. If a function is supplied, then a Dispatcher is returned. Otherwise, func_or_sig may be a signature or a list of signatures, and a function is returned. The returned function accepts another function, which it will compile and then return a Dispatcher. See JIT functions for more information about passing signatures.
Note
A kernel cannot have any return value.
device (bool) – Indicates whether this is a device function.
inline (str) – Enables inlining at the Numba IR level when set to "always". See Notes on Inlining.
forceinline (bool) – Enables inlining at the NVVM IR level when set to True. This is accomplished by adding the alwaysinline function attribute to the function definition.
link (list) – A list of files containing PTX or CUDA C/C++ source to link with the function.
debug – If True, check for exceptions thrown when executing the kernel. Since this degrades performance, this should only be used for debugging purposes. If set to True, then opt should be set to False. Defaults to False. (The default value can be overridden by setting the environment variable NUMBA_CUDA_DEBUGINFO=1.)
fastmath – When True, enables fastmath optimizations as outlined in the CUDA Fast Math documentation.
max_registers – Request that the kernel is limited to using at most this number of registers per thread. The limit may not be respected if the ABI requires a greater number of registers than that requested. Useful for increasing occupancy.
opt (bool) – Whether to compile with optimization enabled. If unspecified, this is determined by the NUMBA_OPT configuration variable; all non-zero values will enable optimization.
lineinfo (bool) – If True, generate a line mapping between source code and assembly code. This enables inspection of the source code in NVIDIA profiling tools and correlation with program counter sampling.
cache (bool) – If True, enables the file-based cache for this function.
launch_bounds (int | tuple[int]) –
Kernel launch bounds, specified as a scalar or a tuple of between one and three items. Tuple items provide:
The maximum number of threads per block,
The minimum number of blocks per SM,
The maximum number of blocks per cluster.
If a scalar is provided, it is used as the maximum number of threads per block.
lto (bool) – Whether to enable LTO. If unspecified, LTO is enabled by default when nvjitlink is available, except for kernels where
debug=True.
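As a non-authoritative sketch, the decorator is commonly used in the following ways; the function and variable names below are illustrative, not part of the API:
from numba import cuda

@cuda.jit                                                # lazy compilation on first launch
def add_one(x):
    i = cuda.grid(1)
    if i < x.size:
        x[i] += 1

@cuda.jit("void(float32[:], float32)", fastmath=True)    # eager compilation for one signature
def scale(x, factor):
    i = cuda.grid(1)
    if i < x.size:
        x[i] *= factor

@cuda.jit(device=True)                                   # device function, callable only from CUDA code
def clamp(v, lo, hi):
    return min(max(v, lo), hi)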
Dispatcher objects#
The usual syntax for configuring a Dispatcher with a launch configuration uses subscripting, with the arguments as follows:
# func is some function decorated with @cuda.jit
func[griddim, blockdim, stream, sharedmem]
The griddim and blockdim arguments specify the size of the grid and
thread blocks, and may be either integers or tuples of length up to 3. The
stream parameter is an optional stream on which the kernel will be launched,
and the sharedmem parameter specifies the size of dynamic shared memory in
bytes.
Subscripting the Dispatcher returns a configuration object that can be called with the kernel arguments:
configured = func[griddim, blockdim, stream, sharedmem]
configured(x, y, z)
However, it is more idiomatic to configure and call the kernel within a single statement:
func[griddim, blockdim, stream, sharedmem](x, y, z)
This is similar to launch configuration in CUDA C/C++:
func<<<griddim, blockdim, sharedmem, stream>>>(x, y, z)
Note
The order of stream and sharedmem is reversed in Numba
compared to CUDA C/C++.
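For example, a complete launch supplying an explicit stream and (unused) dynamic shared memory might look like the following sketch; the array size and block size are arbitrary choices:
import numpy as np
from numba import cuda

@cuda.jit
def double(x):
    i = cuda.grid(1)
    if i < x.size:
        x[i] *= 2

arr = np.arange(1024, dtype=np.float32)
stream = cuda.stream()
d_arr = cuda.to_device(arr, stream=stream)
threads_per_block = 128
blocks_per_grid = (arr.size + threads_per_block - 1) // threads_per_block
double[blocks_per_grid, threads_per_block, stream, 0](d_arr)   # sharedmem=0: no dynamic shared memory
d_arr.copy_to_host(arr, stream=stream)
stream.synchronize()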
Dispatcher objects also provide several utility methods for inspection and creating a specialized instance:
- class numba.cuda.dispatcher.CUDADispatcher(
 - py_func,
 - targetoptions,
 - pipeline_class=<class 'numba.cuda.compiler.CUDACompiler'>,
 - )
 CUDA Dispatcher object. When configured and called, the dispatcher will specialize itself for the given arguments and compute capability (if no suitable specialized version already exists), and launch on the device associated with the current context.
Dispatcher objects are not to be constructed by the user, but instead are created using the
numba.cuda.jit() decorator.
- property extensions#
 A list of objects that must have a prepare_args function. When a specialized kernel is called, each argument will be passed through to the prepare_args (from the last object in this list to the first). The arguments to prepare_args are:
ty the numba type of the argument
val the argument value itself
stream the CUDA stream used for the current call to the kernel
retr a list of zero-arg functions that you may want to append post-call cleanup work to.
The prepare_args function must return a tuple (ty, val), which will be passed in turn to the next right-most extension. After all the extensions have been called, the resulting (ty, val) will be passed into Numba’s default argument marshalling logic.
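A minimal sketch of an extension, assuming a hypothetical user container type MyBuffer that wraps a Numba device array; MyBuffer and its device_array and release members are invented for illustration:
from numba import typeof

class UnwrapMyBuffer:
    def prepare_args(self, ty, val, stream=0, retr=None):
        # MyBuffer, .device_array and .release() are hypothetical user-defined names.
        if isinstance(val, MyBuffer):
            devary = val.device_array
            retr.append(val.release)           # post-call cleanup
            return typeof(devary), devary      # new (ty, val) for the next extension
        return ty, val                         # pass unrecognised arguments through unchanged

kernel.extensions.append(UnwrapMyBuffer())     # `kernel` is some @cuda.jit dispatcher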
- forall(ntasks, tpb=0, stream=0, sharedmem=0)#
 Returns a 1D-configured dispatcher for a given number of tasks.
This assumes that:
the kernel maps the Global Thread ID cuda.grid(1) to tasks on a 1-1 basis.
the kernel checks that the Global Thread ID is upper-bounded by ntasks, and does nothing if it is not.
- Parameters:
 ntasks – The number of tasks.
tpb – The size of a block. An appropriate value is chosen if this parameter is not supplied.
stream – The stream on which the configured dispatcher will be launched.
sharedmem – The number of bytes of dynamic shared memory required by the kernel.
- Returns:
 A configured dispatcher, ready to launch on a set of arguments.
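For example, assuming a kernel written to the pattern described above:
import numpy as np
from numba import cuda

@cuda.jit
def increment(x):
    i = cuda.grid(1)
    if i < x.size:                      # upper-bound check required by forall()
        x[i] += 1

data = np.arange(10000)
increment.forall(data.size)(data)       # block size chosen automatically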
- get_const_mem_size(signature=None)#
 Returns the size in bytes of constant memory used by this kernel for the device in the current context.
- Parameters:
 signature – The signature of the compiled kernel to get constant memory usage for. This may be omitted for a specialized kernel.
- Returns:
 The size in bytes of constant memory allocated by the compiled variant of the kernel for the given signature and current device.
- get_local_mem_per_thread(signature=None)#
 Returns the size in bytes of local memory per thread for this kernel.
- Parameters:
 signature – The signature of the compiled kernel to get local memory usage for. This may be omitted for a specialized kernel.
- Returns:
 The amount of local memory allocated by the compiled variant of the kernel for the given signature and current device.
- get_max_threads_per_block(signature=None)#
 Returns the maximum allowable number of threads per block for this kernel. Exceeding this threshold will result in the kernel failing to launch.
- Parameters:
 signature – The signature of the compiled kernel to get the max threads per block for. This may be omitted for a specialized kernel.
- Returns:
 The maximum allowable threads per block for the compiled variant of the kernel for the given signature and current device.
- get_regs_per_thread(signature=None)#
 Returns the number of registers used by each thread in this kernel for the device in the current context.
- Parameters:
 signature – The signature of the compiled kernel to get register usage for. This may be omitted for a specialized kernel.
- Returns:
 The number of registers used by the compiled variant of the kernel for the given signature and current device.
- get_shared_mem_per_block(signature=None)#
 Returns the size in bytes of statically allocated shared memory for this kernel.
- Parameters:
 signature – The signature of the compiled kernel to get shared memory usage for. This may be omitted for a specialized kernel.
- Returns:
 The amount of shared memory allocated by the compiled variant of the kernel for the given signature and current device.
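A sketch of how these getters might be used, assuming a kernel compiled eagerly for a single float32[:] signature so the resource queries refer to that compiled variant:
from numba import cuda, float32

@cuda.jit("void(float32[:])")
def kernel(x):
    i = cuda.grid(1)
    if i < x.size:
        x[i] += 1

sig = (float32[:],)
print(kernel.get_regs_per_thread(sig))
print(kernel.get_shared_mem_per_block(sig))
print(kernel.get_max_threads_per_block(sig))
print(kernel.get_const_mem_size(sig))
print(kernel.get_local_mem_per_thread(sig))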
- inspect_asm(signature=None)#
 Return this kernel’s PTX assembly code for the device in the current context.
- Parameters:
 signature – A tuple of argument types.
- Returns:
 The PTX code for the given signature, or a dict of PTX codes for all previously-encountered signatures.
- inspect_llvm(signature=None)#
 Return the LLVM IR for this kernel.
- Parameters:
 signature – A tuple of argument types.
- Returns:
 The LLVM IR for the given signature, or a dict of LLVM IR for all previously-encountered signatures.
- inspect_sass(signature=None)#
 Return this kernel’s SASS assembly code for the device in the current context.
- Parameters:
 signature – A tuple of argument types.
- Returns:
 The SASS code for the given signature, or a dict of SASS codes for all previously-encountered signatures.
SASS for the device in the current context is returned.
Requires nvdisasm to be available on the PATH.
- inspect_types(file=None)#
 Produce a dump of the Python source of this function annotated with the corresponding Numba IR and type information. The dump is written to file, or sys.stdout if file is None.
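A hedged sketch of inspecting a kernel that has already been compiled for at least one signature:
# `kernel` is some @cuda.jit dispatcher that has been launched or eagerly compiled.
for sig, ptx in kernel.inspect_asm().items():        # dict keyed by signature
    print(sig, "->", len(ptx), "characters of PTX")
for sig, llvm_ir in kernel.inspect_llvm().items():
    print(sig, "->", len(llvm_ir), "characters of LLVM IR")
kernel.inspect_types()                               # annotated source to sys.stdout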
- specialize(*args)#
 Create a new instance of this dispatcher specialized for the given args.
- property specialized#
 True if the Dispatcher has been specialized.
Kernel Arguments#
The following types are supported for kernel arguments:
Arrays of scalars and structured types. Considerations:
NumPy arrays will be copied to the device prior to kernel invocation, and copied back after the completion of kernel execution. Copying data between the host and device within the flow of kernel launches is quite inefficient, so this will cause a NumbaPerformanceWarning to be emitted. It is recommended that data is copied to the device prior to kernel launch and copied back as required (as in the sketch after this list), or that managed memory is used.
Any object that exposes the CUDA Array Interface will be treated as a device array and handled accordingly.
Scalars. This includes floating point types, signed and unsigned integers, complex types, and enum members.
Records. These may be either a NumPy record or a record obtained from a Numba device array holding records. Similar considerations to those for NumPy arrays apply with respect to copying data between host and device.
Tuples, where the tuples contain supported types. Nested tuples and tuple subclasses (like namedtuple) are supported.
Pointers. In order to use pointer arguments, an explicit signature must be provided when declaring the kernel. See Passing a pointer to a kernel.
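A short sketch combining several of the supported argument kinds (device arrays and a scalar, with inputs copied to the device up front to avoid the NumbaPerformanceWarning); the kernel and sizes are illustrative:
import numpy as np
from numba import cuda

@cuda.jit
def axpy(out, a, x, y):                       # device array, scalar, device arrays
    i = cuda.grid(1)
    if i < out.size:
        out[i] = a * x[i] + y[i]

n = 4096
x = cuda.to_device(np.random.rand(n))         # copy inputs to the device up front
y = cuda.to_device(np.random.rand(n))
out = cuda.device_array(n)
axpy[(n + 127) // 128, 128](out, 2.0, x, y)
result = out.copy_to_host()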
Intrinsic Attributes and Functions#
The remainder of the attributes and functions in this section may only be called from within a CUDA Kernel.
Thread Indexing#
- numba.cuda.threadIdx#
 The thread indices in the current thread block, accessed through the attributes
x, y, and z. Each index is an integer spanning the range from 0 inclusive to the corresponding value of the attribute in numba.cuda.blockDim exclusive.
- numba.cuda.blockIdx#
 The block indices in the grid of thread blocks, accessed through the attributes
x, y, and z. Each index is an integer spanning the range from 0 inclusive to the corresponding value of the attribute in numba.cuda.gridDim exclusive.
- numba.cuda.blockDim#
 The shape of a block of threads, as declared when instantiating the kernel. This value is the same for all threads in a given kernel, even if they belong to different blocks (i.e. each block is “full”).
- numba.cuda.gridDim#
 The shape of the grid of blocks, accessed through the attributes
x, y, and z.
- numba.cuda.laneid#
 The thread index in the current warp, as an integer spanning the range from 0 inclusive to the
numba.cuda.warpsize exclusive.
- numba.cuda.warpsize#
 The size in threads of a warp on the GPU. Currently this is always 32.
- numba.cuda.grid(ndim)#
 Return the absolute position of the current thread in the entire grid of blocks. ndim should correspond to the number of dimensions declared when instantiating the kernel. If ndim is 1, a single integer is returned. If ndim is 2 or 3, a tuple of the given number of integers is returned.
Computation of the first integer is as follows:
cuda.threadIdx.x + cuda.blockIdx.x * cuda.blockDim.x
and is similar for the other two indices, but using the
y and z attributes.
- numba.cuda.gridsize(ndim)#
 Return the absolute size (or shape) in threads of the entire grid of blocks. ndim should correspond to the number of dimensions declared when instantiating the kernel.
Computation of the first integer is as follows:
cuda.blockDim.x * cuda.gridDim.x
and is similar for the other two indices, but using the
y and z attributes.
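For example, cuda.grid() and cuda.gridsize() are commonly combined in a grid-stride loop so that a fixed-size grid can cover an arbitrarily large array:
from numba import cuda

@cuda.jit
def scale_all(x, factor):
    start = cuda.grid(1)        # absolute thread index
    stride = cuda.gridsize(1)   # total number of threads in the grid
    for i in range(start, x.size, stride):
        x[i] *= factor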
Memory Management#
- numba.cuda.shared.array(shape, dtype)#
 Creates an array in the shared memory space of the CUDA kernel with the given shape and dtype. Returns an array with its content uninitialized.
Note
All threads in the same thread block see the same array.
- numba.cuda.local.array(shape, dtype)#
 Creates an array in the local memory space of the CUDA kernel with the given shape and dtype. Returns an array with its content uninitialized.
Note
Each thread sees a unique array.
- numba.cuda.const.array_like(ary)#
 Copies ary into the constant memory space of the CUDA kernel at compile time. Returns an array like the ary argument.
Note
All threads and blocks see the same array.
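The three memory spaces can be sketched in a single illustrative kernel; the sizes, the WEIGHTS constant, and the computation are arbitrary, and the kernel assumes a block size of at most 128 threads:
import numpy as np
from numba import cuda, float32

WEIGHTS = np.array([0.25, 0.5, 0.25], dtype=np.float32)

@cuda.jit
def smooth(out, inp):
    tile = cuda.shared.array(128, dtype=float32)    # one array per thread block
    scratch = cuda.local.array(3, dtype=float32)    # one array per thread
    w = cuda.const.array_like(WEIGHTS)              # read-only, captured at compile time
    i = cuda.grid(1)
    if i < inp.size:
        tile[cuda.threadIdx.x] = inp[i]
    cuda.syncthreads()
    if i < inp.size:
        scratch[0] = tile[cuda.threadIdx.x] * w[1]
        out[i] = scratch[0]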
Synchronization and Atomic Operations#
- numba.cuda.atomic.add(array, idx, value)#
 Perform array[idx] += value. Supports int32, int64, float32 and float64 only. The idx argument can be an integer or a tuple of integer indices for indexing into multi-dimensional arrays. The number of elements in idx must match the number of dimensions of array.
Returns the value of array[idx] before storing the new value. Behaves like an atomic load.
- numba.cuda.atomic.sub(array, idx, value)#
 Perform array[idx] -= value. Supports int32, int64, float32 and float64 only. The idx argument can be an integer or a tuple of integer indices for indexing into multi-dimensional arrays. The number of elements in idx must match the number of dimensions of array.
Returns the value of array[idx] before storing the new value. Behaves like an atomic load.
- numba.cuda.atomic.and_(array, idx, value)#
 Perform array[idx] &= value. Supports int32, uint32, int64, and uint64 only. The idx argument can be an integer or a tuple of integer indices for indexing into multi-dimensional arrays. The number of elements in idx must match the number of dimensions of array.
Returns the value of array[idx] before storing the new value. Behaves like an atomic load.
- numba.cuda.atomic.or_(array, idx, value)#
 Perform array[idx] |= value. Supports int32, uint32, int64, and uint64 only. The idx argument can be an integer or a tuple of integer indices for indexing into multi-dimensional arrays. The number of elements in idx must match the number of dimensions of array.
Returns the value of array[idx] before storing the new value. Behaves like an atomic load.
- numba.cuda.atomic.xor(array, idx, value)#
 Perform array[idx] ^= value. Supports int32, uint32, int64, and uint64 only. The idx argument can be an integer or a tuple of integer indices for indexing into multi-dimensional arrays. The number of elements in idx must match the number of dimensions of array.
Returns the value of array[idx] before storing the new value. Behaves like an atomic load.
- numba.cuda.atomic.exch(array, idx, value)#
 Perform array[idx] = value. Supports int32, uint32, int64, and uint64 only. The idx argument can be an integer or a tuple of integer indices for indexing into multi-dimensional arrays. The number of elements in idx must match the number of dimensions of array.
Returns the value of array[idx] before storing the new value. Behaves like an atomic load.
- numba.cuda.atomic.inc(array, idx, value)#
 Perform array[idx] = (0 if array[idx] >= value else array[idx] + 1). Supports uint32 and uint64 only. The idx argument can be an integer or a tuple of integer indices for indexing into multi-dimensional arrays. The number of elements in idx must match the number of dimensions of array.
Returns the value of array[idx] before storing the new value. Behaves like an atomic load.
- numba.cuda.atomic.dec(array, idx, value)#
 Perform array[idx] = (value if (array[idx] == 0) or (array[idx] > value) else array[idx] - 1). Supports uint32 and uint64 only. The idx argument can be an integer or a tuple of integer indices for indexing into multi-dimensional arrays. The number of elements in idx must match the number of dimensions of array.
Returns the value of array[idx] before storing the new value. Behaves like an atomic load.
- numba.cuda.atomic.max(array, idx, value)#
 Perform array[idx] = max(array[idx], value). Supports int32, int64, float32 and float64 only. The idx argument can be an integer or a tuple of integer indices for indexing into multi-dimensional arrays. The number of elements in idx must match the number of dimensions of array.
Returns the value of array[idx] before storing the new value. Behaves like an atomic load.
- numba.cuda.atomic.cas(array, idx, old, value)#
 Perform if array[idx] == old: array[idx] = value. Supports int32, int64, uint32, uint64 indexes only. The idx argument can be an integer or a tuple of integer indices for indexing into multi-dimensional arrays. The number of elements in idx must match the number of dimensions of array.
Returns the value of array[idx] before storing the new value. Behaves like an atomic compare and swap.
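A typical use of the atomics is a histogram built with atomic additions; a minimal sketch, assuming input values in [0, 1):
import numpy as np
from numba import cuda

@cuda.jit
def histogram(bins, data, nbins):
    i = cuda.grid(1)
    if i < data.size:
        b = min(int(data[i] * nbins), nbins - 1)   # map a value in [0, 1) to a bin
        cuda.atomic.add(bins, b, 1)                # old value is returned but ignored here

data = np.random.rand(100000)
bins = np.zeros(16, dtype=np.int64)
histogram[(data.size + 255) // 256, 256](bins, data, 16)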
- numba.cuda.syncthreads()#
 Synchronize all threads in the same thread block. This function implements the same pattern as barriers in traditional multi-threaded programming: this function waits until all threads in the block call it, at which point it returns control to all its callers.
- numba.cuda.syncthreads_count(predicate)#
 An extension to
numba.cuda.syncthreads where the return value is a count of the threads where predicate is true.
- numba.cuda.syncthreads_and(predicate)#
 An extension to
numba.cuda.syncthreads where 1 is returned if predicate is true for all threads or 0 otherwise.
- numba.cuda.syncthreads_or(predicate)#
 An extension to
numba.cuda.syncthreads where 1 is returned if predicate is true for any thread or 0 otherwise.
Warning
All syncthreads functions must be called by every thread in the thread-block. Failing to do so may result in undefined behavior.
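A sketch of the predicated barrier variants, counting how many elements in each block satisfy a condition; note that every thread in the block reaches the syncthreads_count call:
from numba import cuda

@cuda.jit
def count_positive(result, data):
    i = cuda.grid(1)
    pred = i < data.size and data[i] > 0
    n = cuda.syncthreads_count(pred)     # per-block count of true predicates
    if cuda.threadIdx.x == 0:
        cuda.atomic.add(result, 0, n)    # accumulate block counts into result[0]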
Cooperative Groups#
- numba.cuda.cg.this_grid()#
 Get the current grid group.
- Returns:
 The current grid group
- Return type:
 GridGroup
- class numba.cuda.cg.GridGroup#
 A grid group. Users should not construct a GridGroup directly - instead, get the current grid group using
cg.this_grid().
- sync()#
 Synchronize the current grid group.
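A sketch of a grid-wide synchronization point; note that using this_grid() requires a cooperative launch, so the grid must be small enough for all blocks to be resident at once:
from numba import cuda

@cuda.jit
def two_phase(x):
    i = cuda.grid(1)
    if i < x.size:
        x[i] += 1
    g = cuda.cg.this_grid()
    g.sync()                 # all blocks complete phase one before phase two starts
    if i < x.size:
        x[i] *= 2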
Memory Fences#
The memory fences are used to guarantee that the effects of memory operations are visible to other threads within the same thread-block, the same GPU device, and the same system (across GPUs on global memory). Memory loads and stores are guaranteed not to move across the memory fences by optimization passes.
Warning
The memory fences are considered to be an advanced API and most
use cases should use the thread barrier (e.g. syncthreads()).
- numba.cuda.threadfence()#
 A memory fence at device level (within the GPU).
- numba.cuda.threadfence_block()#
 A memory fence at thread block level.
- numba.cuda.threadfence_system()#
 A memory fence at system level (across GPUs).
Warp Intrinsics#
The argument membermask is a 32 bit integer mask with each bit
corresponding to a thread in the warp, with 1 meaning the thread is in the
subset of threads within the function call. The membermask must be all 1 if
the GPU compute capability is below 7.x.
- numba.cuda.syncwarp(membermask)#
 Synchronize a masked subset of the threads in a warp.
- numba.cuda.all_sync(membermask, predicate)#
 If the
predicate is true for all threads in the masked warp, then a non-zero value is returned, otherwise 0 is returned.
- numba.cuda.any_sync(membermask, predicate)#
 If the
predicate is true for any thread in the masked warp, then a non-zero value is returned, otherwise 0 is returned.
- numba.cuda.eq_sync(membermask, predicate)#
 If the boolean
predicate is the same for all threads in the masked warp, then a non-zero value is returned, otherwise 0 is returned.
- numba.cuda.ballot_sync(membermask, predicate)#
 Returns a mask of all threads in the warp whose
predicate is true, and are within the given mask.
- numba.cuda.shfl_sync(membermask, value, src_lane)#
 Shuffles
value across the masked warp and returns the value from src_lane. If this is outside the warp, then the given value is returned.
- numba.cuda.shfl_up_sync(membermask, value, delta)#
 Shuffles
value across the masked warp and returns the value from laneid - delta. If this is outside the warp, then the given value is returned.
- numba.cuda.shfl_down_sync(membermask, value, delta)#
 Shuffles
value across the masked warp and returns the value from laneid + delta. If this is outside the warp, then the given value is returned.
- numba.cuda.shfl_xor_sync(membermask, value, lane_mask)#
 Shuffles
value across the masked warp and returns the value from laneid ^ lane_mask.
- numba.cuda.match_any_sync(membermask, value, lane_mask)#
 Returns a mask of threads that have the same value as the given value from within the masked warp.
- numba.cuda.match_all_sync(membermask, value, lane_mask)#
 Returns a tuple of (mask, pred), where mask is a mask of threads that have the same value as the given value from within the masked warp if they all have the same value, otherwise it is 0, and pred is a boolean indicating whether or not all threads in the masked warp have the same value.
- numba.cuda.activemask()#
 Returns a 32-bit integer mask of all currently active threads in the calling warp. The Nth bit is set if the Nth lane in the warp is active when activemask() is called. Inactive threads are represented by 0 bits in the returned mask. Threads which have exited the kernel are always marked as inactive.
- numba.cuda.lanemask_lt()#
 Returns a 32-bit integer mask of all lanes (including inactive ones) with ID less than the current lane.
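A sketch of a warp-level sum using the shuffle intrinsics, assuming the block size is a multiple of the warp size so the full mask is valid:
from numba import cuda

@cuda.jit
def warp_sums(out, data):
    i = cuda.grid(1)
    val = data[i] if i < data.size else 0.0
    mask = 0xFFFFFFFF                       # every lane in the warp participates
    offset = 16
    while offset > 0:
        val += cuda.shfl_down_sync(mask, val, offset)
        offset //= 2
    if cuda.laneid == 0:                    # lane 0 now holds the warp's sum
        cuda.atomic.add(out, 0, val)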
Integer Intrinsics#
A subset of the CUDA Math API’s integer intrinsics are available. For further documentation, including semantics, please refer to the CUDA Toolkit documentation.
- numba.cuda.popc(x)#
 Returns the number of bits set in
x.
- numba.cuda.brev(x)#
 Returns the reverse of the bit pattern of
x. For example, 0b10110110 becomes 0b01101101.
- numba.cuda.clz(x)#
 Returns the number of leading zeros in
x.
- numba.cuda.ffs(x)#
 Returns the position of the first (least significant) bit set to 1 in
x, where the least significant bit position is 1. ffs(0) returns 0.
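A brief sketch exercising these intrinsics; x and out are assumed to be unsigned 32-bit integer arrays, and the output layout is arbitrary:
from numba import cuda

@cuda.jit
def bit_stats(out, x):
    i = cuda.grid(1)
    if i < x.size:
        out[i, 0] = cuda.popc(x[i])   # number of set bits
        out[i, 1] = cuda.clz(x[i])    # leading zeros
        out[i, 2] = cuda.ffs(x[i])    # 1-based index of lowest set bit (0 if x[i] == 0)
        out[i, 3] = cuda.brev(x[i])   # bit-reversed value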
Floating Point Intrinsics#
A subset of the CUDA Math API’s floating point intrinsics are available. For further documentation, including semantics, please refer to the single and double precision parts of the CUDA Toolkit documentation.
- numba.cuda.fma()#
 Perform the fused multiply-add operation. Named after the
fma and fmaf in the C API, but maps to the fma.rn.f32 and fma.rn.f64 (round-to-nearest-even) PTX instructions.
- numba.cuda.cbrt(x)#
 Perform the cube root operation, x ** (1/3). Named after the functions
cbrt and cbrtf in the C API. Supports float32 and float64 arguments only.
16-bit Floating Point Intrinsics#
Warning
Starting with numba 0.18, LTO is required for performant float16 operations.
The functions in the cuda.fp16 module are used to operate on 16-bit
floating point operands. These functions return a 16-bit floating point result.
To determine whether Numba supports compiling code that uses the float16
type in the current configuration, use:
- numba.cuda.is_float16_supported()#
 Return
True if 16-bit floats are supported, False otherwise.
To check whether a device supports float16, use its
supports_float16
attribute.
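For example, both checks might be combined before attempting to compile float16 code:
from numba import cuda

if cuda.is_float16_supported() and cuda.get_current_device().supports_float16:
    print("float16 kernels can be compiled and run on this device")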
- numba.cuda.fp16.hfma(a, b, c)#
 Perform the fused multiply-add operation
(a * b) + c on 16-bit floating point arguments in round to nearest mode. Maps to the fma.rn.f16 PTX instruction.
Returns the 16-bit floating point result of the fused multiply-add.
- numba.cuda.fp16.hadd(a, b)#
 Perform the add operation
a + b on 16-bit floating point arguments in round to nearest mode. Maps to the add.f16 PTX instruction.
Returns the 16-bit floating point result of the addition.
- numba.cuda.fp16.hsub(a, b)#
 Perform the subtract operation
a - b on 16-bit floating point arguments in round to nearest mode. Maps to the sub.f16 PTX instruction.
Returns the 16-bit floating point result of the subtraction.
- numba.cuda.fp16.hmul(a, b)#
 Perform the multiply operation
a * b on 16-bit floating point arguments in round to nearest mode. Maps to the mul.f16 PTX instruction.
Returns the 16-bit floating point result of the multiplication.
- numba.cuda.fp16.hdiv(a, b)#
 Perform the divide operation
a / b on 16-bit floating point arguments in round to nearest mode.
Returns the 16-bit floating point result of the division.
- numba.cuda.fp16.hneg(a)#
 Perform the negation operation
-a on the 16-bit floating point argument. Maps to the neg.f16 PTX instruction.
Returns the 16-bit floating point result of the negation.
- numba.cuda.fp16.habs(a)#
 Perform the absolute value operation
|a| on the 16-bit floating point argument.
Returns the 16-bit floating point result of the absolute value operation.
- numba.cuda.fp16.hsin(a)#
 Calculates the trigonometric sine of the 16-bit floating point argument.
Returns the 16-bit floating point result of the sine operation.
- numba.cuda.fp16.hcos(a)#
 Calculates the trigonometric cosine of the 16-bit floating point argument.
Returns the 16-bit floating point result of the cosine operation.
- numba.cuda.fp16.hlog(a)#
 Calculates the natural logarithm of the 16-bit floating point argument.
Returns the 16-bit floating point result of the natural log operation.
- numba.cuda.fp16.hlog10(a)#
 Calculates the base 10 logarithm of the 16-bit floating point argument.
Returns the 16-bit floating point result of the log base 10 operation.
- numba.cuda.fp16.hlog2(a)#
 Calculates the base 2 logarithm of the 16-bit floating point argument.
Returns the 16-bit floating point result of the log base 2 operation.
- numba.cuda.fp16.hexp(a)#
 Calculates the natural exponential operation of the 16-bit floating point argument.
Returns the 16-bit floating point result of the exponential operation.
- numba.cuda.fp16.hexp10(a)#
 Calculates the base 10 exponential of the 16-bit floating point argument.
Returns the 16-bit floating point result of the exponential operation.
- numba.cuda.fp16.hexp2(a)#
 Calculates the base 2 exponential of the 16-bit floating point argument.
Returns the 16-bit floating point result of the exponential operation.
- numba.cuda.fp16.hfloor(a)#
 Calculates the floor operation, the largest integer less than or equal to
a, on the 16-bit floating point argument.
Returns the 16-bit floating point result of the floor operation.
- numba.cuda.fp16.hceil(a)#
 Calculates the ceiling operation, the smallest integer greater than or equal to
a, on the 16-bit floating point argument.
Returns the 16-bit floating point result of the ceil operation.
- numba.cuda.fp16.hsqrt(a)#
 Calculates the square root operation of the 16-bit floating point argument.
Returns the 16-bit floating point result of the square root operation.
- numba.cuda.fp16.hrsqrt(a)#
 Calculates the reciprocal of the square root of the 16-bit floating point argument.
Returns the 16-bit floating point result of the reciprocal square root operation.
- numba.cuda.fp16.hrcp(a)#
 Calculates the reciprocal of the 16-bit floating point argument.
Returns the 16-bit floating point result of the reciprocal.
- numba.cuda.fp16.hrint(a)#
 Round the input 16-bit floating point argument to nearest integer value.
Returns the 16-bit floating point result of the rounding.
- numba.cuda.fp16.htrunc(a)#
 Truncate the input 16-bit floating point argument to the nearest integer that does not exceed the input argument in magnitude.
Returns the 16-bit floating point result of the truncation.
- numba.cuda.fp16.heq(a, b)#
 Perform the comparison operation
a == b on 16-bit floating point arguments.
Returns a boolean.
- numba.cuda.fp16.hne(a, b)#
 Perform the comparison operation
a != b on 16-bit floating point arguments.
Returns a boolean.
- numba.cuda.fp16.hgt(a, b)#
 Perform the comparison operation
a > b on 16-bit floating point arguments.
Returns a boolean.
- numba.cuda.fp16.hge(a, b)#
 Perform the comparison operation
a >= b on 16-bit floating point arguments.
Returns a boolean.
- numba.cuda.fp16.hlt(a, b)#
 Perform the comparison operation
a < b on 16-bit floating point arguments.
Returns a boolean.
- numba.cuda.fp16.hle(a, b)#
 Perform the comparison operation
a <= b on 16-bit floating point arguments.
Returns a boolean.
- numba.cuda.fp16.hmax(a, b)#
 Perform the operation
a if a > b else b.
Returns a 16-bit floating point value.
- numba.cuda.fp16.hmin(a, b)#
 Perform the operation
a if a < b else b.
Returns a 16-bit floating point value.
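A sketch of a kernel built from these operations, assuming the current configuration supports float16 (see is_float16_supported() above); the sizes and block shape are arbitrary:
import numpy as np
from numba import cuda

@cuda.jit
def fp16_add(out, x, y):
    i = cuda.grid(1)
    if i < out.size:
        out[i] = cuda.fp16.hadd(x[i], y[i])   # 16-bit addition in round-to-nearest mode

n = 256
x = np.random.rand(n).astype(np.float16)
y = np.random.rand(n).astype(np.float16)
out = np.zeros(n, dtype=np.float16)
fp16_add[(n + 63) // 64, 64](out, x, y)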
Control Flow Instructions#
A subset of CUDA’s control flow instructions are directly available as
intrinsics. Avoiding branches is a key way to improve CUDA performance, and
using these intrinsics means you don’t have to rely on the nvcc optimizer
identifying and removing branches. For further documentation, including
semantics, please refer to the relevant CUDA Toolkit documentation.
- numba.cuda.selp()#
 Select between two expressions, depending on the value of the first argument. Similar to LLVM’s
select instruction.
Timer Intrinsics#
- numba.cuda.nanosleep(ns)#
 Suspends the thread for a sleep duration approximately equal to the delay
ns, specified in nanoseconds.