Kernel
Defines
- MARLIN_NAMESPACE_NAME marlin_moe_wna16
- MARLIN_KERNEL_PARAMS
  const int4 *__restrict__ A, const int4 *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
  const int4 *__restrict__ b_bias_ptr, const float *__restrict__ a_scales_ptr, \
  const int4 *__restrict__ scales_ptr, const uint16_t *__restrict__ global_scale_ptr, \
  const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
  const int32_t *__restrict__ sorted_token_ids_ptr, const int32_t *__restrict__ expert_ids_ptr, \
  const int32_t *__restrict__ num_tokens_past_padded_ptr, const float *__restrict__ topk_weights_ptr, int top_k, \
  bool mul_topk_weights, int num_groups, int prob_m, int prob_n, int prob_k, int *locks, bool has_bias, \
  bool use_atomic_add, bool use_fp32_reduce
- namespace MARLIN_NAMESPACE_NAME
Functions
- template<trt_edgellm::marlin_dtypes::ScalarTypeId const a_type_id, trt_edgellm::marlin_dtypes::ScalarTypeId const b_type_id, trt_edgellm::marlin_dtypes::ScalarTypeId const c_type_id, trt_edgellm::marlin_dtypes::ScalarTypeId const s_type_id, int const threads, int const thread_m_blocks, int const thread_n_blocks, int const thread_k_blocks, bool const m_block_size_8, int const stages, int const group_blocks, bool const is_zp_float> __global__ void Marlin (MARLIN_KERNEL_PARAMS)