Kernel

Defines

MARLIN_NAMESPACE_NAME marlin_moe_wna16
MARLIN_KERNEL_PARAMS

const int4 *__restrict__ A, const int4 *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ b_bias_ptr, const float *__restrict__ a_scales_ptr, \
const int4 *__restrict__ scales_ptr, const uint16_t *__restrict__ global_scale_ptr, \
const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, int num_groups, int prob_m, int prob_n, int prob_k, int *locks, bool has_bias, \
bool use_atomic_add, bool use_fp32_reduce
namespace MARLIN_NAMESPACE_NAME

Functions

template<trt_edgellm::marlin_dtypes::ScalarTypeId const a_type_id, trt_edgellm::marlin_dtypes::ScalarTypeId const b_type_id, trt_edgellm::marlin_dtypes::ScalarTypeId const c_type_id, trt_edgellm::marlin_dtypes::ScalarTypeId const s_type_id, int const threads, int const thread_m_blocks, int const thread_n_blocks, int const thread_k_blocks, bool const m_block_size_8, int const stages, int const group_blocks, bool const is_zp_float> __global__ void Marlin (MARLIN_KERNEL_PARAMS)