39 namespace threadblock {
45 typename ThreadblockShape,
48 typename ElementOutput,
49 int ElementsPerAccess,
50 typename ElementAccumulator
58 typename ThreadblockShape_,
61 typename ElementOutput_,
73 using WarpShape = WarpShape_;
74 static int const kPartitionsK = PartitionsK;
75 using ElementOutput = ElementOutput_;
76 static int const kElementsPerAccess = ElementsPerAccess;
85 static int const kTensorOpRows = 16;
86 static int const kWarpSize = 32;
87 static int const kInterleavedTilesM = WarpShape::kM / 32;
90 !(ThreadblockShape::kM % WarpShape::kM) &&
91 !(ThreadblockShape::kM % WarpShape::kM),
"Divisibility");
95 ThreadblockShape::kM / WarpShape::kM,
96 ThreadblockShape::kN / WarpShape::kN,
101 static int const kThreads = WarpCount::kCount * kWarpSize;
104 ThreadblockShape::kN,
117 WarpShape::kM / kTensorOpRows
139 typename ThreadblockShape_,
142 typename ElementOutput_,
143 int ElementsPerAccess
154 using WarpShape = WarpShape_;
155 static int const kPartitionsK = PartitionsK;
156 using ElementOutput = ElementOutput_;
157 static int const kElementsPerAccess = ElementsPerAccess;
166 static int const kTensorOpRows = 16;
167 static int const kWarpSize = 32;
168 static int const kInterleavedTilesM = WarpShape::kM / 32;
171 !(ThreadblockShape::kM % WarpShape::kM) &&
172 !(ThreadblockShape::kM % WarpShape::kM),
"Divisibility");
176 ThreadblockShape::kM / WarpShape::kM,
177 ThreadblockShape::kN / WarpShape::kN,
182 static int const kThreads = WarpCount::kCount * kWarpSize;
185 ThreadblockShape::kN,
198 WarpShape::kM / kTensorOpRows
float ElementAccumulator
Definition: default_thread_map_volta_tensor_op.h:158
Definition: output_tile_thread_map.h:228
Definition: aligned_buffer.h:35
Tuple defining point in output tile.
Definition: output_tile_thread_map.h:57
Epilogue for threadblock scoped GEMMs using Tensor Ops.
IEEE half-precision floating-point type.
Definition: half.h:126
ThreadblockShape_ ThreadblockShape
Definition: default_thread_map_volta_tensor_op.h:72
Defines common types used for all GEMM-like operators.
Defines the size of an element in bits.
Definition: numeric_types.h:42
ThreadblockShape_ ThreadblockShape
Definition: default_thread_map_volta_tensor_op.h:153
Shape of a matrix multiply-add operation.
Definition: include/cutlass/gemm/gemm.h:57
Defines the optimal thread map for TensorOp accumulator layouts.
Definition: default_thread_map_volta_tensor_op.h:52