69 using ElementC =
typename FragmentC::value_type;
87 !(FragmentC::kElements % kElementsPerAccess),
88 "The number of accumulators must be divisible by the access size.");
118 ptr_C(ptr_C), stride_n(stride_n_ / kElementsPerAccess), stride_k(stride_k_ / kElementsPerAccess) {
130 struct alignas((kAccessSizeInBits / 8)) AccessType {
131 Array<ElementC, kElementsPerAccess> storage;
135 AccessType *pointer_;
154 pointer_(reinterpret_cast<AccessType *>(params.
ptr_C)),
170 AccessType *pointer = pointer_ +
171 tb_tile_coord.
m() * kThreadblockAccesses +
172 tb_tile_coord.
n() * stride_n_ +
173 tb_tile_coord.
k() * stride_k_;
176 AccessType
const * src_pointer =
reinterpret_cast<AccessType
const *
>(&accum);
Definition: aligned_buffer.h:35
Shared storage allocation needed by the epilogue.
Definition: epilogue_workspace.h:124
static int const kAccessSizeInBits
Optimize for 128b accesses.
Definition: epilogue_workspace.h:74
Definition: include/cutlass/gemm/gemm.h:94
static int const kWarpAccesses
Total number of vectorized accesses in warp (in units of vector)
Definition: epilogue_workspace.h:91
CUTLASS_HOST_DEVICE Params(ElementC *ptr_C, int stride_n_, int stride_k_)
Definition: epilogue_workspace.h:113
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
static int const kIterations
Number of stores per thread.
Definition: epilogue_workspace.h:84
CUTLASS_DEVICE void operator()(cutlass::gemm::GemmCoord problem_size, cutlass::gemm::GemmCoord tb_tile_coord, FragmentC const &accum)
Streams the result to global memory.
Definition: epilogue_workspace.h:164
Shape_ Shape
Definition: epilogue_workspace.h:67
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:145
typename FragmentC::value_type ElementC
Definition: epilogue_workspace.h:69
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
Defines the size of an element in bits.
Definition: numeric_types.h:42
ElementC * ptr_C
Pointer to C matrix.
Definition: epilogue_workspace.h:100
FragmentC_ FragmentC
Definition: epilogue_workspace.h:68
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
CUTLASS_DEVICE EpilogueWorkspace(Params const &params, SharedStorage &, int warp_idx, int lane_idx)
Constructor.
Definition: epilogue_workspace.h:147
static int const kWarpCount
Definition: epilogue_workspace.h:71
int stride_n
Stride between tiles along the GEMM N dimension (in units of vectors)
Definition: epilogue_workspace.h:103
static int const kElementsPerAccess
Vector length of accesses.
Definition: epilogue_workspace.h:80
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
int stride_k
Stride between tiles along the GEMM K dimension (in units of vectors)
Definition: epilogue_workspace.h:106
Parameters structure.
Definition: epilogue_workspace.h:97
Definition: epilogue_workspace.h:64
Basic include for CUTLASS.
static int const kWarpSize
Warp size from the perspective of memory operations.
Definition: epilogue_workspace.h:77
static int const kThreadblockAccesses
Total number of vectorized accesses in threadblock tile (in units of vector)
Definition: epilogue_workspace.h:94