Marlin Template
Defines
- MARLIN_NAMESPACE_NAME marlin_moe_wna16
- STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)
  static_assert(std::is_same<scalar_t, half>::value || std::is_same<scalar_t, nv_bfloat16>::value, \
  "only float16 and bfloat16 is supported");
- namespace MARLIN_NAMESPACE_NAME
Functions
- template<int count, trt_edgellm::marlin_dtypes::ScalarTypeId type_id> inline __device__ void ldsm (typename MarlinScalarType< type_id >::FragA &frag_a, void const *smem_ptr)
- template<trt_edgellm::marlin_dtypes::ScalarTypeId type_id> inline __device__ void scale (typename MarlinScalarType< type_id >::FragB &frag_b, typename MarlinScalarType< type_id >::FragS &frag_s, int i)
- template<trt_edgellm::marlin_dtypes::ScalarTypeId type_id> inline __device__ void scale4 (typename MarlinScalarType< type_id >::FragB &frag_b, typename MarlinScalarType< type_id >::FragS &frag_s_1, typename MarlinScalarType< type_id >::FragS &frag_s_2, typename MarlinScalarType< type_id >::FragS &frag_s_3, typename MarlinScalarType< type_id >::FragS &frag_s_4, int i)
- template<trt_edgellm::marlin_dtypes::ScalarTypeId type_id> inline __device__ void scale_float (float *c, typename MarlinScalarType< type_id >::FragS &s)
- inline __device__ void barrier_acquire (int *lock, int count)
- inline __device__ void barrier_release (int *lock, bool reset=false)
- inline __device__ void wait_negative_and_add (int *lock)
- template<trt_edgellm::marlin_dtypes::ScalarTypeId const a_type_id, trt_edgellm::marlin_dtypes::ScalarTypeId const b_type_id, trt_edgellm::marlin_dtypes::ScalarTypeId const c_type_id, trt_edgellm::marlin_dtypes::ScalarTypeId const s_type_id, int const threads, int const thread_m_blocks, int const thread_n_blocks, int const thread_k_blocks, bool const m_block_size_8, int const stages, int const group_blocks, bool const is_zp_float> __global__ void Marlin (int4 const *__restrict__ A, int4 const *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp, int4 const *__restrict__ b_bias_ptr, float const *__restrict__ a_scales_ptr, int4 const *__restrict__ scales_ptr, uint16_t const *__restrict__ global_scale_ptr, int4 const *__restrict__ zp_ptr, int const *__restrict__ g_idx, int32_t const *__restrict__ sorted_token_ids_ptr, int32_t const *__restrict__ expert_ids_ptr, int32_t const *__restrict__ num_tokens_past_padded_ptr, float const *__restrict__ topk_weights_ptr, int top_k, bool mul_topk_weights, int num_groups, int prob_m, int prob_n, int prob_k, int *locks, bool has_bias, bool use_atomic_add, bool use_fp32_reduce)