CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm_pipelined.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/cutlass.h"
32 
33 #include "cutlass/aligned_buffer.h"
34 #include "cutlass/array.h"
35 
36 #include "cutlass/numeric_types.h"
37 #include "cutlass/matrix_shape.h"
38 
39 #include "cutlass/gemm/gemm.h"
40 
42 
43 namespace cutlass {
44 namespace gemm {
45 namespace kernel {
46 
48 
/// Threadblock-scoped GEMM kernel: each threadblock runs the main loop `Mma` over the
/// operand tiles selected by `ThreadblockSwizzle`, then hands the resulting accumulators
/// to `Epilogue` for output.
///
/// Template parameters:
///   Mma                - threadblock-scoped multiply-accumulate; provides IteratorA,
///                        IteratorB, SharedStorage, FragmentC and Shape::kM / Shape::kN.
///   Epilogue           - output stage; provides Params and SharedStorage.
///   ThreadblockSwizzle - maps the executing threadblock to a tile offset via
///                        get_tile_offset().
///
/// Parameters:
///   problem_size     - GEMM extents (M, N, K). They bound the A (M x K) and B (K x N)
///                      iterators and the epilogue's (M x N) output.
///   grid_tiled_shape - number of tiles along M and N. A block whose tile offset lies
///                      outside this shape returns without doing any work.
///   params_A, ref_A  - iterator parameters and tensor reference for operand A.
///   params_B, ref_B  - iterator parameters and tensor reference for operand B.
///   params_epilogue  - forwarded unchanged to the Epilogue constructor.
///
/// Launch notes:
///   - Threads are identified by threadIdx.x only (1-D thread block).
///   - The warp index is broadcast with a full-warp __shfl_sync mask.
///     NOTE(review): this presumably requires blockDim.x to be a multiple of 32;
///     confirm against the launcher.
///   - Shared memory is a static __shared__ union of Mma::SharedStorage and
///     Epilogue::SharedStorage, so the two stages overlay the same bytes.
template <typename Mma, typename Epilogue, typename ThreadblockSwizzle>
__global__ void GemmPipelined(
  cutlass::gemm::GemmCoord problem_size,
  cutlass::gemm::GemmCoord grid_tiled_shape,
  typename Mma::IteratorA::Params params_A,
  typename Mma::IteratorA::TensorRef ref_A,
  typename Mma::IteratorB::Params params_B,
  typename Mma::IteratorB::TensorRef ref_B,
  typename Epilogue::Params params_epilogue
  ) {

  // The main loop and the epilogue are used in sequence (main loop first), so their
  // shared storage is overlaid in one union.
  __shared__ union {
    typename Mma::SharedStorage main_loop;
    typename Epilogue::SharedStorage epilogue;
  } smem;

  ThreadblockSwizzle swizzle;

  // Tile coordinates (in units of tiles along M and N) assigned to this threadblock.
  cutlass::gemm::GemmCoord tile = swizzle.get_tile_offset();

  // Blocks mapped outside the tiled grid have nothing to compute.
  // NOTE(review): this exit is assumed block-uniform (the tile offset presumably derives
  // from the block index), which the full-mask __shfl_sync below relies on.
  bool const tile_in_grid =
    tile.m() < grid_tiled_shape.m() && tile.n() < grid_tiled_shape.n();

  if (!tile_in_grid) {
    return;
  }

  int thread_idx = threadIdx.x;

  // Warp index is broadcast from lane 0 so the value is uniform across the warp.
  int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
  int lane_idx = threadIdx.x % 32;

  // Starting element coordinates of this block's tiles within A (rows M, cols K)
  // and B (rows K, cols N). M and N are scaled by the tile shape; K is used as-is.
  // NOTE(review): the K offset is not scaled, which is only harmless when the swizzle
  // always yields k() == 0 (cf. the identity-swizzle assumption further down).
  cutlass::MatrixCoord a_origin{
    tile.m() * Mma::Shape::kM,
    tile.k()
  };

  cutlass::MatrixCoord b_origin{
    tile.k(),
    tile.n() * Mma::Shape::kN
  };

  // Global-memory iterators over the A and B operands, bounded by the problem extents.
  typename Mma::IteratorA iter_A(
    params_A,
    ref_A.data(),
    {problem_size.m(), problem_size.k()},
    thread_idx,
    a_origin);

  typename Mma::IteratorB iter_B(
    params_B,
    ref_B.data(),
    {problem_size.k(), problem_size.n()},
    thread_idx,
    b_origin);

  //
  // Main loop
  //

  Mma mma(smem.main_loop, thread_idx, warp_idx, lane_idx);

  typename Mma::FragmentC accum;
  accum.clear();

  // accum is both the source and the destination accumulator, starting from zero.
  // NOTE(review): problem_size (a GemmCoord) is passed as the first argument; confirm
  // that Mma's call operator accepts it rather than a K-iteration count.
  mma(problem_size, accum, iter_A, iter_B, accum);

  //
  // Epilogue
  //

  Epilogue epilogue(
    params_epilogue,
    smem.epilogue,
    thread_idx,
    warp_idx,
    lane_idx);

  // Re-query the tile offset. The output origin below assumes an identity swizzle.
  tile = swizzle.get_tile_offset();

  MatrixCoord output_origin(
    tile.m() * Mma::Shape::kM,
    tile.n() * Mma::Shape::kN
  );

  epilogue({problem_size.m(), problem_size.n()}, accum, output_origin);
}
145 
147 
148 } // namespace kernel
149 } // namespace gemm
150 } // namespace cutlass
Definition: aligned_buffer.h:35
Definition: include/cutlass/gemm/gemm.h:94
Defines common types used for all GEMM-like operators.
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
__global__ void GemmPipelined(cutlass::gemm::GemmCoord problem_size, cutlass::gemm::GemmCoord grid_tiled_shape, typename Mma::IteratorA::Params params_A, typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::Params params_B, typename Mma::IteratorB::TensorRef ref_B, typename Epilogue::Params params_epilogue)
Definition: gemm_pipelined.h:50
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:145
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
Defines a Shape template for matrix tiles.
AlignedBuffer is a container for trivially copyable elements suitable for use in unions and shared memory.
Top-level include for all CUTLASS numeric types.
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
Basic include for CUTLASS.
Definition: matrix_coord.h:39