gemm.h

Functions for matrix multiplication.

Typedefs

typedef void *NVTEMatmulConfig

Configuration for matrix multiplication.

typedef void *NVTEGroupedMatmulConfig

Configuration for grouped matrix multiplication.

Enums

enum NVTEMatmulConfigAttribute

Type of option for matrix multiplication.

Values:

enumerator kNVTEMatmulConfigBiasTensor

Bias tensor

If provided, the bias tensor is applied in the GEMM epilogue.

enumerator kNVTEMatmulConfigDBiasTensor

Bias gradient tensor

If provided, the bias gradient tensor will be filled in the GEMM epilogue.

enumerator kNVTEMatmulConfigWithGELUEpilogue

Whether to compute GELU in GEMM epilogue.

enumerator kNVTEMatmulConfigWithDGELUEpilogue

Whether to compute GELU backward in GEMM epilogue.

enumerator kNVTEMatmulConfigEpilogueAuxTensor

Auxiliary tensor for GEMM epilogue.

For GELU, this will be filled with the GELU input. For GELU backward, this is expected to already be filled with the GELU input.

enumerator kNVTEMatmulConfigUseSplitAccumulator

Whether to use split accumulator for FP8 GEMM.

enumerator kNVTEMatmulConfigSMCount

Number of streaming multiprocessors to use in GEMM kernel.

enumerator kNVTEMatmulConfigNumAttributes

enum NVTEGroupedMatmulConfigAttribute

Type of option for grouped matrix multiplication.

Values:

enumerator kNVTEGroupedMatmulConfigAvgM

Average M dimension hint

Optional hint for average M dimension across all matrices in the group. Used by cuBLASLt for algorithm selection heuristics. If not set, computed automatically from D’s logical shape.

enumerator kNVTEGroupedMatmulConfigAvgN

Average N dimension hint

Optional hint for average N dimension across all matrices in the group. Used by cuBLASLt for algorithm selection heuristics. If not set, computed automatically from D’s logical shape.

enumerator kNVTEGroupedMatmulConfigAvgK

Average K (reduction) dimension hint

Optional hint for average K dimension across all matrices in the group. Used by cuBLASLt for algorithm selection heuristics. If not set, computed automatically from A’s logical shape.

enumerator kNVTEGroupedMatmulConfigSMCount

Number of streaming multiprocessors to use in GEMM kernel.

enumerator kNVTEGroupedMatmulConfigUseSplitAccumulator

Split accumulator mode. Only taken into account on Hopper. Default: true.

enumerator kNVTEGroupedMatmulConfigNumAttributes

Functions

NVTEMatmulConfig nvte_create_matmul_config()

Create a matrix multiplication configuration.

void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, void *buf, size_t size_in_bytes, size_t *size_written)

Query an option in matrix multiplication configuration.

Parameters:
  • config[in] Matrix multiplication configuration.

  • attr[in] Option type.

  • buf[out] Memory address to write option value to. Ignored if NULL.

  • size_in_bytes[in] Size of buf.

  • size_written[out] Number of bytes that have been written to buf. If buf is NULL, then the number of bytes that would have been written.
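The size-query contract above (pass NULL to learn the required size, then call again with a buffer) can be illustrated with a standalone mock; this is not Transformer Engine code, only a sketch of the documented behavior with a single int attribute:

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>

// Standalone mock of the documented query contract -- NOT Transformer Engine
// code, just an illustration. The stored attribute here is a single int.
struct MockConfig {
  int sm_count = 0;
};

void mock_get_attribute(const MockConfig &config, void *buf,
                        size_t size_in_bytes, size_t *size_written) {
  const size_t needed = sizeof(config.sm_count);
  if (size_written != nullptr) {
    *size_written = needed;  // reported even when buf is NULL
  }
  if (buf == nullptr) {
    return;  // size query only
  }
  // Copy at most size_in_bytes bytes of the attribute value.
  std::memcpy(buf, &config.sm_count,
              needed <= size_in_bytes ? needed : size_in_bytes);
}
```

The two-call pattern follows directly: first call with buf == NULL to obtain the required size, allocate a buffer of that size, then call again to read the value.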

void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, const void *buf, size_t size_in_bytes)

Set an option in matrix multiplication configuration.

Parameters:
  • config[inout] Matrix multiplication configuration.

  • attr[in] Option type.

  • buf[in] Memory address to read option value from.

  • size_in_bytes[in] Size of buf.

void nvte_destroy_matmul_config(NVTEMatmulConfig config)

Destroy a matrix multiplication configuration.
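A typical lifecycle, sketched from the functions above (tensor creation, the GEMM call itself, and error handling are elided; bias is assumed to be a valid NVTETensor created elsewhere):

```cpp
// Sketch of the C-API config lifecycle using the functions documented above.
// bias is assumed to be a valid NVTETensor created elsewhere.
NVTEMatmulConfig config = nvte_create_matmul_config();

// Attribute values are passed as raw bytes: a pointer to the value plus its size.
nvte_set_matmul_config_attribute(config, kNVTEMatmulConfigBiasTensor,
                                 &bias, sizeof(bias));
bool with_gelu = true;
nvte_set_matmul_config_attribute(config, kNVTEMatmulConfigWithGELUEpilogue,
                                 &with_gelu, sizeof(with_gelu));

// ... pass config to nvte_cublas_gemm_v2 ...

nvte_destroy_matmul_config(config);
```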

NVTEGroupedMatmulConfig nvte_create_grouped_matmul_config()

Create a grouped matrix multiplication configuration.

void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, NVTEGroupedMatmulConfigAttribute attr, void *buf, size_t size_in_bytes, size_t *size_written)

Query an option in grouped matrix multiplication configuration.

Parameters:
  • config[in] Grouped matrix multiplication configuration.

  • attr[in] Option type.

  • buf[out] Memory address to write option value to. Ignored if NULL.

  • size_in_bytes[in] Size of buf.

  • size_written[out] Number of bytes that have been written to buf. If buf is NULL, then the number of bytes that would have been written.

void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config, NVTEGroupedMatmulConfigAttribute attr, const void *buf, size_t size_in_bytes)

Set an option in grouped matrix multiplication configuration.

Parameters:
  • config[in] Grouped matrix multiplication configuration.

  • attr[in] Option type.

  • buf[in] Memory address to read option value from.

  • size_in_bytes[in] Size of buf.

void nvte_destroy_grouped_matmul_config(NVTEGroupedMatmulConfig config)

Destroy a grouped matrix multiplication configuration.

void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, NVTETensor workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream)

Compute matrix multiplication of 2 matrices, potentially fused with other operations (deprecated).

This has been deprecated in favor of nvte_cublas_gemm_v2.

Computes:

  • D = AB if both bias and pre_gelu_out are empty tensors

  • D = AB + bias if pre_gelu_out is empty and bias is not empty

  • D = GELU(AB + bias) if both bias and pre_gelu_out are not empty tensors

Parameters:
  • A[in] The A matrix.

  • B[in] The B matrix.

  • D[inout] Output matrix.

  • bias[in] Bias tensor.

  • pre_gelu_out[inout] Output matrix before GELU activation.

  • transa[in] Whether A matrix is transposed.

  • transb[in] Whether B matrix is transposed.

  • grad[in] Whether this operation is part of the gradient computation.

  • workspace[out] Workspace tensor.

  • accumulate[in] Whether to accumulate the result into the D matrix.

  • use_split_accumulator[in] Whether to use split accumulator in the FP8 GEMM.

  • math_sm_count[in] Number of GPU SMs to use (default=0: use cuBLAS heuristics)

  • stream[in] CUDA stream used for the operation.

void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A, const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D, NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream)

Compute matrix multiplication of 2 matrices, potentially fused with other operations.

Computes:

  • D = alpha * op(A) * op(B) + beta * C

Parameters:
  • transa[in] Whether to transpose A matrix.

  • transb[in] Whether to transpose B matrix.

  • alpha[in] Scaling factor applied to matmul output.

  • A[in] A matrix.

  • B[in] B matrix.

  • beta[in] Scaling factor applied to C matrix.

  • C[in] C matrix.

  • D[out] Output matrix.

  • workspace[in] Workspace tensor.

  • config[in] Additional configuration.

  • stream[in] CUDA stream used for the operation.
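As a host-side reference of the formula only (the real call executes on the GPU through cuBLAS), the computation D = alpha * op(A) * op(B) + beta * C for row-major matrices can be sketched as:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side reference of D = alpha * op(A) * op(B) + beta * C for row-major
// matrices, where op(X) = X^T when the corresponding trans flag is set.
// This sketches the math only; nvte_cublas_gemm_v2 runs on the GPU via cuBLAS.
std::vector<float> reference_gemm(bool transa, bool transb, float alpha,
                                  const std::vector<float> &A,
                                  const std::vector<float> &B, float beta,
                                  const std::vector<float> &C,  // empty => no beta term
                                  size_t m, size_t n, size_t k) {
  std::vector<float> D(m * n, 0.0f);
  for (size_t i = 0; i < m; ++i) {
    for (size_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (size_t p = 0; p < k; ++p) {
        // op(A) is m x k: A is stored as k x m when transa is set.
        const float a = transa ? A[p * m + i] : A[i * k + p];
        // op(B) is k x n: B is stored as n x k when transb is set.
        const float b = transb ? B[j * k + p] : B[p * n + j];
        acc += a * b;
      }
      D[i * n + j] = alpha * acc + (C.empty() ? 0.0f : beta * C[i * n + j]);
    }
  }
  return D;
}
```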

void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, NVTETensor workspace, float alpha, float beta, bool use_split_accumulator, int math_sm_count, cudaStream_t stream)

Compute matrix multiplication of 2 matrices, potentially fused with other operations, allowing a scaling factor for the GEMM result and the accumulation input (deprecated).

This has been deprecated in favor of nvte_cublas_gemm_v2.

Computes:

  • D = alpha*AB if both bias and pre_gelu_out are empty tensors

  • D = alpha*AB + bias if pre_gelu_out is empty and bias is not empty

  • D = GELU(alpha*AB + bias) if both bias and pre_gelu_out are not empty tensors

Parameters:
  • A[in] The A matrix.

  • B[in] The B matrix.

  • D[inout] Output matrix.

  • bias[in] Bias tensor.

  • pre_gelu_out[inout] Output matrix before GELU activation.

  • transa[in] Whether A matrix is transposed.

  • transb[in] Whether B matrix is transposed.

  • grad[in] Whether this operation is part of the gradient computation.

  • workspace[out] Workspace tensor.

  • alpha[in] Scaling factor applied to the result of the GEMM.

  • beta[in] Scaling factor applied to original value of D when accumulating into it. beta=0 means no accumulation.

  • use_split_accumulator[in] Whether to use split accumulator in the FP8 GEMM.

  • math_sm_count[in] Number of GPU SMs to use (default=0: use cuBLAS heuristics)

  • stream[in] CUDA stream used for the operation.

void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, NVTETensor workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, const NVTETensor counter, cudaStream_t stream)

Compute matrix multiplication of 2 matrices with chunking and atomic counters.

Computes:

  • D = AB if both bias and pre_gelu_out are empty tensors

  • D = AB + bias if pre_gelu_out is empty and bias is not empty

  • D = GELU(AB + bias) if both bias and pre_gelu_out are not empty tensors

Warning

cuBLAS atomic GEMM uses a beta API and is not tested for all use cases.

Parameters:
  • A[in] The A matrix.

  • B[in] The B matrix.

  • D[inout] Output matrix.

  • bias[in] Bias tensor.

  • pre_gelu_out[inout] Output matrix before GELU activation.

  • transa[in] Whether A matrix is transposed.

  • transb[in] Whether B matrix is transposed.

  • grad[in] Whether this operation is part of the gradient computation.

  • workspace[out] Workspace tensor.

  • accumulate[in] Whether to accumulate the result into the D matrix.

  • use_split_accumulator[in] Whether to use split accumulator in the FP8 GEMM.

  • math_sm_count[in] Number of GPU SMs to use (default=0: use cuBLAS heuristics)

  • m_split[in] Number of chunks/splits along m-dimension for Atomic GEMM.

  • n_split[in] Number of chunks/splits along n-dimension for Atomic GEMM.

  • gemm_producer[in] Whether Atomic GEMM is the producer or consumer.

  • counter[inout] counter[chunk_i]=0 indicates chunk_i has been produced.

  • stream[in] CUDA stream used for the operation.

void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms, bool transa, bool transb, bool grad, NVTETensor *workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream)

Compute multiple pairs of matrix multiplication, potentially fused with other operations, on multiple streams.

Computes:

  • D = AB if both bias and pre_gelu_out are empty tensors

  • D = AB + bias if pre_gelu_out is empty and bias is not empty

  • D = GELU(AB + bias) if both bias and pre_gelu_out are not empty tensors

Parameters:
  • A[in] The list of A matrices.

  • B[in] The list of B matrices.

  • D[inout] List of output matrices.

  • bias[in] List of bias tensors.

  • pre_gelu_out[inout] List of output matrices before GELU activation.

  • num_gemms[in] Number of GEMMs to compute.

  • transa[in] Whether A matrix is transposed.

  • transb[in] Whether B matrix is transposed.

  • grad[in] Whether this operation is part of the gradient computation.

  • workspace[out] List of workspace tensors.

  • accumulate[in] Whether to accumulate the result into the D matrix.

  • use_split_accumulator[in] Whether to use split accumulator in the FP8 GEMM.

  • math_sm_count[in] Number of GPU SMs to use (default=0: use cuBLAS heuristics)

  • stream[in] CUDA stream to wait on.

size_t nvte_get_grouped_gemm_setup_workspace_size(size_t num_tensors)

Return the required size in bytes for the setup workspace of grouped GEMM.

The setup workspace stores pointer arrays and per-matrix dimension arrays used by the grouped GEMM kernel. Its size depends only on the number of tensors (GEMMs) in the group and is independent of matrix dimensions.

Pass the result as the size of the workspace_setup tensor in nvte_grouped_gemm.

Parameters:
  • num_tensors[in] Number of tensors (GEMMs) in the group.

Returns:

Required size in bytes for workspace_setup.

void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream)

Convert a device array of int32 values to int64 values.

Useful for preparing group_sizes for nvte_grouped_gemm when the caller holds int32 sizes and needs int64 values on the device.

Parameters:
  • src[in] Device pointer to source int32 array.

  • dst[out] Device pointer to destination int64 array.

  • n[in] Number of elements.

  • stream[in] CUDA stream.

void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, const NVTETensor beta, NVTETensor workspace_setup, NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config, cudaStream_t stream)

Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C.

Performs batched GEMM on a collection of matrices with potentially different shapes. All tensors in the group must have compatible dimensions for matrix multiplication. Uses NVTEGroupedTensor to efficiently handle collections of tensors with contiguous memory layout and shape metadata.

Requirements:

  • cuBLAS 13.2+ (CUDA 13.1+)

  • Blackwell (SM100) or newer GPU architecture

  • A, B, C (if provided), and D must have the same num_tensors

  • For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i]

  • Shape compatibility: if transa=false, transb=false:

    • A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i])

Note

Requires cuBLAS 13.2+ (CUDA 13.1+) and Blackwell (SM100) or newer GPU architecture. Will error at runtime if compiled with an older cuBLAS version or run on a pre-Blackwell GPU.

Parameters:
  • A[in] Input grouped tensor A.

  • transa[in] Whether to transpose A matrices.

  • B[in] Input grouped tensor B.

  • transb[in] Whether to transpose B matrices.

  • C[in] Input grouped tensor C (can be NULL for beta=0).

  • D[out] Output grouped tensor D.

  • alpha[in] Scale multipliers for A @ B (NVTETensor with num_tensors elements).

  • beta[in] Scale multipliers for C (NVTETensor with num_tensors elements).

  • workspace_setup[in] Workspace tensor for pointer-array setup.

  • workspace_cublas[in] Workspace tensor for cuBLAS operations.

  • config[in] Additional configuration (can be NULL for defaults).

  • stream[in] CUDA stream for the operation.
void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num_a_tensors, int transa, const NVTEGroupedTensor B, int transb, const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, const NVTETensor beta, NVTETensor workspace_setup, NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config, cudaStream_t stream)

Grouped matrix multiplication with discrete A input tensors.

Identical to nvte_grouped_gemm, but A is provided as a list of tensors instead of an NVTEGroupedTensor. This enables using discrete per-expert weights as the A input for grouped GEMM.

Parameters:
  • A_list[in] List of A tensors (length = num_tensors).

  • num_a_tensors[in] Number of tensors in A_list.
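A call sequence for the grouped path, sketched from the functions above; grouped tensor and alpha/beta tensor creation, workspace allocation, and error handling are elided:

```cpp
// Sketch of the grouped GEMM call sequence using the functions documented
// above. Grouped tensors A, B, C, D, the alpha/beta tensors, the workspace
// tensors, and the CUDA stream are assumed to exist already.
const size_t num_tensors = 8;  // number of GEMMs in the group

// 1. Size the setup workspace (depends only on the number of GEMMs).
size_t setup_bytes = nvte_get_grouped_gemm_setup_workspace_size(num_tensors);
// ... allocate workspace_setup of setup_bytes and a cuBLAS workspace ...

// 2. Optionally pass shape hints for algorithm selection.
NVTEGroupedMatmulConfig config = nvte_create_grouped_matmul_config();
int64_t avg_m = 1024;
nvte_set_grouped_matmul_config_attribute(config, kNVTEGroupedMatmulConfigAvgM,
                                         &avg_m, sizeof(avg_m));

// 3. D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i] for each i.
nvte_grouped_gemm(A, /*transa=*/0, B, /*transb=*/0, C, D, alpha, beta,
                  workspace_setup, workspace_cublas, config, stream);

nvte_destroy_grouped_matmul_config(config);
```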

void nvte_grouped_gemm_with_discrete_out(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, const NVTETensor *C_list, size_t num_c_tensors, NVTETensor *D_list, size_t num_d_tensors, const NVTETensor alpha, const NVTETensor beta, NVTETensor workspace_setup, NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config, cudaStream_t stream)

Grouped matrix multiplication with discrete output tensors.

Identical to nvte_grouped_gemm, but C and D are provided as lists of tensors instead of NVTEGroupedTensors. This enables accumulation into non-contiguous per-expert buffers (e.g. for weight gradients).

Note

All tensors in C_list and D_list must share the same dtype.

Parameters:
  • C_list[in] Optional list of C tensors (length = num_tensors).

  • num_c_tensors[in] Number of tensors in C_list (Can be 0 if C is not provided).

  • D_list[out] List of D tensors (length = num_tensors).

  • num_d_tensors[in] Number of tensors in D_list.

void nvte_grouped_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTensor bias, cudaStream_t stream)

Grouped bias add for grouped GEMM outputs.

Requires a uniform last dimension across all output and bias tensors.

namespace transformer_engine

Namespace containing C++ API of Transformer Engine.

Functions

void nvte_cublas_handle_init()

TE/JAX cudaGraph support requires cuBLAS initialization to happen outside of the capture region. This function is a helper that calls cublasCreate(), which allocates memory for the handle. It is called in the initialization phase of the related XLA custom calls.

struct GroupedMatmulConfigWrapper
#include <gemm.h>

C++ wrapper for NVTEGroupedMatmulConfig.

Public Functions

inline GroupedMatmulConfigWrapper()
GroupedMatmulConfigWrapper(const GroupedMatmulConfigWrapper&) = delete
GroupedMatmulConfigWrapper &operator=(const GroupedMatmulConfigWrapper&) = delete
inline GroupedMatmulConfigWrapper(GroupedMatmulConfigWrapper &&other)
inline GroupedMatmulConfigWrapper &operator=(GroupedMatmulConfigWrapper &&other)
inline ~GroupedMatmulConfigWrapper()
inline operator NVTEGroupedMatmulConfig() const noexcept

Get the underlying NVTEGroupedMatmulConfig.

Returns:

NVTEGroupedMatmulConfig held by this GroupedMatmulConfigWrapper.

inline void set_avg_m(int64_t avg_m)

Set average M dimension hint for algorithm selection.

inline void set_avg_n(int64_t avg_n)

Set average N dimension hint for algorithm selection.

inline void set_avg_k(int64_t avg_k)

Set average K dimension hint for algorithm selection.

inline void set_sm_count(int sm_count)

Set number of streaming multiprocessors to use.

inline void set_use_split_accumulator(bool use_split_accumulator)

Set split accumulator mode. Only taken into account on Hopper.

Private Members

NVTEGroupedMatmulConfig config_ = nullptr

Wrapped NVTEGroupedMatmulConfig.

struct MatmulConfigWrapper
#include <gemm.h>

C++ wrapper for NVTEMatmulConfig.

Public Functions

inline MatmulConfigWrapper()
MatmulConfigWrapper(const MatmulConfigWrapper&) = delete
MatmulConfigWrapper &operator=(const MatmulConfigWrapper&) = delete
inline MatmulConfigWrapper(MatmulConfigWrapper &&other)

Move constructor.

inline MatmulConfigWrapper &operator=(MatmulConfigWrapper &&other)

Move-assignment operator.

inline ~MatmulConfigWrapper()
inline operator NVTEMatmulConfig() const noexcept

Get the underlying NVTEMatmulConfig.

Returns:

NVTEMatmulConfig held by this MatmulConfigWrapper.

inline void set_bias_tensor(NVTETensor bias_tensor)

Set bias tensor.

inline void set_dbias_tensor(NVTETensor dbias_tensor)

Set bias gradient tensor.

inline void set_with_gelu_epilogue(bool with_gelu_epilogue)

Set whether to compute GELU in GEMM epilogue.

inline void set_with_dgelu_epilogue(bool with_dgelu_epilogue)

Set whether to compute GELU backward in GEMM epilogue.

inline void set_epilogue_aux_tensor(NVTETensor epilogue_aux_tensor)

Set auxiliary tensor for GEMM epilogue.

inline void set_use_split_accumulator(bool use_split_accumulator)

Set whether to use split accumulator for FP8 GEMM.

inline void set_sm_count(int sm_count)

Set number of streaming multiprocessors to use in GEMM kernel.

Private Members

NVTEMatmulConfig config_ = nullptr

Wrapped NVTEMatmulConfig.
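A usage sketch of the wrapper (bias, A, B, C, D, workspace, and stream are assumed to be valid handles created elsewhere); the RAII wrapper destroys the config when it goes out of scope:

```cpp
// RAII sketch: MatmulConfigWrapper owns the underlying NVTEMatmulConfig and
// destroys it on scope exit. bias, A, B, C, D, workspace, and stream are
// assumed to be valid handles created elsewhere.
transformer_engine::MatmulConfigWrapper config;
config.set_bias_tensor(bias);
config.set_with_gelu_epilogue(true);
config.set_use_split_accumulator(true);  // relevant for FP8 GEMMs

float alpha = 1.0f, beta = 0.0f;
// The conversion operator yields the wrapped NVTEMatmulConfig, so the
// wrapper can be passed wherever an NVTEMatmulConfig is expected.
nvte_cublas_gemm_v2(/*transa=*/0, /*transb=*/0, &alpha, A, B, &beta, C, D,
                    workspace, config, stream);
```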