CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
default_gemm_splitk_parallel.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
25 
36 #pragma once
37 
38 #include "cutlass/cutlass.h"
41 
43 
44 namespace cutlass {
45 namespace gemm {
46 namespace kernel {
47 
49 
// Class template whose body (closed at listing line 122) forwards its parameters to
// DefaultGemm and re-exports that instantiation's Mma and Epilogue types.
//
// NOTE(review): this listing skips line 88, which this page's own index gives as the
// location of a definition in this file. The declaration that opens the body is therefore
// not visible here and must be restored from the original header before this compiles.
//
// Template parameters are undocumented in this extract. Their names suggest element types,
// layouts and alignments for operands A, B and C, followed by accumulator type, operator
// class, architecture tag, tile shapes, epilogue functor, threadblock swizzle, stage count
// and math operator -- confirm against the original header.
50 template <
52  typename ElementA_,
54  typename LayoutA_,
56  int kAlignmentA,
58  typename ElementB_,
60  typename LayoutB_,
62  int kAlignmentB,
64  typename ElementC_,
66  typename LayoutC_,
68  typename ElementAccumulator,
70  typename OperatorClass,
72  typename ArchTag,
74  typename ThreadblockShape,
76  typename WarpShape,
78  typename InstructionShape,
80  typename EpilogueOutputOp,
82  typename ThreadblockSwizzle,
84  int Stages,
86  typename Operator
87 >
89 
// Underlying DefaultGemm instantiation. All parameters are forwarded in declaration order
// except ElementC_, which does not appear in the visible argument list.
92  using Default = DefaultGemm<
93  ElementA_,
94  LayoutA_,
95  kAlignmentA,
96  ElementB_,
97  LayoutB_,
98  kAlignmentB,
// ElementAccumulator is passed here and again two arguments later; ElementC_ is not used.
// NOTE(review): presumably so partial results are kept at accumulator precision -- confirm.
99  ElementAccumulator,
100  LayoutC_,
101  ElementAccumulator,
102  OperatorClass,
103  ArchTag,
104  ThreadblockShape,
105  WarpShape,
106  InstructionShape,
107  EpilogueOutputOp,
108  ThreadblockSwizzle,
109  Stages,
// NOTE(review): meaning of this boolean is defined by DefaultGemm's parameter list, which is
// not visible here; presumably a serial split-K switch -- confirm against default_gemm.h.
110  false,
111  Operator
112  >;
113 
// Define the matrix multiply operator (taken from the DefaultGemm instantiation above).
115  using Mma = typename Default::Mma;
116 
// Define the epilogue (taken from the DefaultGemm instantiation above).
118  using Epilogue = typename Default::Epilogue;
119 
// NOTE(review): listing lines 120-121 are absent from this extract; any member declared
// there is not shown.
122 };
123 
125 
126 } // namespace kernel
127 } // namespace gemm
128 } // namespace cutlass
129 
Definition: default_gemm.h:116
Definition: aligned_buffer.h:35
Definition: default_gemm_splitk_parallel.h:88
typename Default::Mma Mma
Define the matrix multiply operator.
Definition: default_gemm_splitk_parallel.h:115
Template for GEMM performing a reduction over K partitions in parallel.
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with the appropriate threadblock-scoped epilogue.
typename Default::Epilogue Epilogue
Define the epilogue.
Definition: default_gemm_splitk_parallel.h:118
Definition: kernel/gemm_splitk_parallel.h:49
Basic include for CUTLASS.