49 template <typename Mma, typename Epilogue, typename ThreadblockSwizzle>
53 typename Mma::IteratorA::Params params_A,
54 typename Mma::IteratorA::TensorRef ref_A,
55 typename Mma::IteratorB::Params params_B,
56 typename Mma::IteratorB::TensorRef ref_B,
57 typename Epilogue::Params params_epilogue
62 typename Mma::SharedStorage main_loop;
63 typename Epilogue::SharedStorage epilogue;
67 ThreadblockSwizzle threadblock_swizzle;
71 if (grid_tiled_shape.m() <= tb_tile_offset.m() ||
72 grid_tiled_shape.n() <= tb_tile_offset.n()) {
79 tb_tile_offset.m() * Mma::Shape::kM,
85 tb_tile_offset.n() * Mma::Shape::kN
89 int tb_thread_id = threadIdx.x;
92 typename Mma::IteratorA iterator_A(
95 {problem_size.m(), problem_size.k()},
99 typename Mma::IteratorB iterator_B(
102 {problem_size.k(), problem_size.n()},
106 int warp_id = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
107 int lane_id = threadIdx.x % 32;
114 Mma mma(shared_storage.main_loop, tb_thread_id, warp_id, lane_id);
116 typename Mma::FragmentC accumulators;
118 accumulators.clear();
121 mma(problem_size, accumulators, iterator_A, iterator_B, accumulators);
129 shared_storage.epilogue,
134 tb_tile_offset = threadblock_swizzle.get_tile_offset();
138 tb_tile_offset.m() * Mma::Shape::kM,
139 tb_tile_offset.n() * Mma::Shape::kN
143 epilogue({problem_size.m(), problem_size.n()}, accumulators, threadblock_offset);
Definition: aligned_buffer.h:35
Definition: include/cutlass/gemm/gemm.h:94
Defines common types used for all GEMM-like operators.
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
__global__ void GemmPipelined(cutlass::gemm::GemmCoord problem_size, cutlass::gemm::GemmCoord grid_tiled_shape, typename Mma::IteratorA::Params params_A, typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::Params params_B, typename Mma::IteratorB::TensorRef ref_B, typename Epilogue::Params params_epilogue)
Definition: gemm_pipelined.h:50
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:145
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union.
Defines a Shape template for matrix tiles.
AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared memory.
Top-level include for all CUTLASS numeric types.
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
Basic include for CUTLASS.
Definition: matrix_coord.h:39