CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
device/gemm_splitk_parallel.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for GEMM performing a reduction over K partitions in parallel.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"

#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/gemm.h"

#include "cutlass/gemm/kernel/default_gemm_splitk_parallel.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"

#include "cutlass/epilogue/thread/conversion_op.h"
#include "cutlass/reduction/kernel/reduce_split_k.h"
#include "cutlass/reduction/thread/reduction_operators.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// GEMM device-level operator computing partial products over K partitions in parallel,
/// followed by a separate reduction kernel that combines them into the final result.
template <
  /// Element type for A matrix operand
  typename ElementA_,
  /// Layout type for A matrix operand
  typename LayoutA_,
  /// Element type for B matrix operand
  typename ElementB_,
  /// Layout type for B matrix operand
  typename LayoutB_,
  /// Element type for C and D matrix operands
  typename ElementC_,
  /// Layout type for C and D matrix operands
  typename LayoutC_,
  /// Element type for internal accumulation
  typename ElementAccumulator_ = ElementC_,
  /// Operator class tag
  typename OperatorClass_ = arch::OpClassSimt,
  /// Tag indicating architecture to tune for
  typename ArchTag_ = arch::Sm70,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::ThreadblockShape,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::WarpShape,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::InstructionShape,
  /// Epilogue output operator
  typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::EpilogueOutputOp,
  /// Epilogue conversion operator applied to split-K partial accumulators
  typename ConvertScaledOp_ = cutlass::epilogue::thread::Convert<
      ElementAccumulator_,
      DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                               ElementAccumulator_,
                               ElementAccumulator_>::EpilogueOutputOp::kCount,
      ElementAccumulator_>,
  /// Reduction operator
  typename ReductionOp_ = cutlass::reduction::thread::ReduceAdd<
      ElementAccumulator_, typename EpilogueOutputOp_::ElementAccumulator,
      EpilogueOutputOp_::kCount>,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle_ =
      threadblock::GemmSplitKHorizontalThreadblockSwizzle,
  /// Number of stages used in the pipelined mainloop
  int Stages =
      DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                               ElementC_, ElementAccumulator_>::kStages,
  /// Access granularity of A matrix in units of elements
  int kAlignmentA =
      DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                               ElementC_, ElementAccumulator_>::kAlignmentA,
  /// Access granularity of B matrix in units of elements
  int kAlignmentB =
      DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                               ElementC_, ElementAccumulator_>::kAlignmentB,
  /// Operation performed by GEMM
  typename Operator_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::Operator>
class GemmSplitKParallel {
 public:

  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using ConvertScaledOp = ConvertScaledOp_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ReductionOp = ReductionOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  using Operator = Operator_;
  static int const kStages = Stages;

  /// GEMM kernel
  using GemmKernel = typename kernel::DefaultGemmSplitKParallel<
    ElementA,
    LayoutA,
    kAlignmentA,
    ElementB,
    LayoutB,
    kAlignmentB,
    ElementAccumulator,
    LayoutC,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    ConvertScaledOp,
    ThreadblockSwizzle,
    kStages,
    Operator
  >::GemmKernel;

  /// Reduction kernel
  using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<
    cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
    EpilogueOutputOp,
    ReductionOp
  >;

  //
  //
  //

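  // Execution model: Kernel<GemmKernel> writes one M-by-N partial product per
  // K-slice into a device workspace, and Kernel<ReductionKernel> then reduces
  // the slices element-wise into D, applying EpilogueOutputOp to the result.
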
  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmCoord problem_size;
    TensorRef<ElementA const, LayoutA> ref_A;
    TensorRef<ElementB const, LayoutB> ref_B;
    TensorRef<ElementC const, LayoutC> ref_C;
    TensorRef<ElementC, LayoutC> ref_D;
    typename EpilogueOutputOp::Params epilogue;
    int split_k_slices;
    typename ConvertScaledOp::Params convert;
    typename ReductionOp::Params reduction;

    //
    // Methods
    //

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Arguments() { }

    /// Constructs an Arguments structure
    CUTLASS_HOST_DEVICE
    Arguments(
      GemmCoord problem_size_,
      TensorRef<ElementA const, LayoutA> ref_A_,
      TensorRef<ElementB const, LayoutB> ref_B_,
      TensorRef<ElementC const, LayoutC> ref_C_,
      TensorRef<ElementC, LayoutC> ref_D_,
      typename EpilogueOutputOp::Params epilogue_ =
        typename EpilogueOutputOp::Params(),
      int split_k_slices = 1,
      typename ConvertScaledOp::Params convert_ =
        typename ConvertScaledOp::Params(),
      typename ReductionOp::Params reduction_ =
        typename ReductionOp::Params()
    ):
      problem_size(problem_size_),
      ref_A(ref_A_),
      ref_B(ref_B_),
      ref_C(ref_C_),
      ref_D(ref_D_),
      epilogue(epilogue_),
      split_k_slices(split_k_slices),
      convert(convert_),
      reduction(reduction_) { }
  };

private:

  /// Kernel parameters object
  typename GemmKernel::Params gemm_params_;

  /// Reduction kernel parameters object
  typename ReductionKernel::Params reduction_params_;

public:

  /// Constructs the GEMM.
  GemmSplitKParallel() { }

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {

    // TODO

    return Status::kSuccess;
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {

    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
      args.problem_size,
      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
      args.split_k_slices);

    return sizeof(ElementAccumulator_) * size_t(args.problem_size.m()) * size_t(args.problem_size.n()) * grid_shape.k();
  }
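
  // Worked example (illustrative numbers): a 1024 x 1024 x K problem with float
  // accumulators split into grid_shape.k() == 8 slices needs
  // 4 B * 1024 * 1024 * 8 = 32 MiB of workspace to hold the partial products.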

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace) {

    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
      args.problem_size,
      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
      args.split_k_slices);

    // Define a reference to the workspace - this is an aligned region in device memory.
    if (!workspace) {
      return Status::kErrorWorkspaceNull;
    }

    TensorRef<ElementAccumulator_, layout::RowMajor> ref_workspace(
      static_cast<ElementAccumulator_ *>(workspace),
      args.problem_size.n());

    int64_t partition_stride = int64_t(args.problem_size.m()) * int64_t(args.problem_size.n());

    // Initialize the Params structure
    gemm_params_ = typename GemmKernel::Params{
      args.problem_size,
      grid_shape,
      args.ref_A.non_const_ref(),
      args.ref_B.non_const_ref(),
      ref_workspace,
      args.convert,
      partition_stride
    };

    reduction_params_ = typename ReductionKernel::Params(
      args.problem_size.mn(),
      grid_shape.k(),
      partition_stride,
      ref_workspace,
      args.ref_D,
      args.ref_C.non_const_ref(),
      args.epilogue
    );

    return Status::kSuccess;
  }

  /// Lightweight update given a subset of arguments.
  Status update(Arguments const &args, void *workspace = nullptr) {

    if (!workspace) {
      return Status::kErrorWorkspaceNull;
    }

    gemm_params_.ref_A.reset(args.ref_A.data());
    gemm_params_.ref_B.reset(args.ref_B.data());
    gemm_params_.ref_D.reset(workspace);

    reduction_params_.ref_D.reset(args.ref_D.data());
    reduction_params_.ref_C.reset(args.ref_C.data());

    return Status::kSuccess;
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {

    //
    // Launch GEMM kernel
    //

    ThreadblockSwizzle threadblock_swizzle;

    dim3 grid = threadblock_swizzle.get_grid_shape(gemm_params_.grid_tiled_shape);
    dim3 block(GemmKernel::kThreadCount, 1, 1);

    cudaError_t result;

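    // Kernels requesting 48 KB or more of dynamic shared memory must opt in
    // explicitly; the attributes below raise the per-kernel limit and shared
    // memory carveout before launch (supported on Volta and newer).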
    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
    if (smem_size >= (48 << 10)) {

      result = cudaFuncSetAttribute(
        Kernel<GemmKernel>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        smem_size);

      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }

      result = cudaFuncSetAttribute(
        Kernel<GemmKernel>,
        cudaFuncAttributePreferredSharedMemoryCarveout, 100);

      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }
    }

    Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(gemm_params_);

    result = cudaGetLastError();

    if (result != cudaSuccess) {
      return Status::kErrorInternal;
    }

    //
    // Launch reduction kernel
    //

    block = ReductionKernel::block_shape();
    grid = ReductionKernel::grid_shape(gemm_params_.problem_size.mn());

    Kernel<ReductionKernel><<<grid, block, 0, stream>>>(reduction_params_);

    result = cudaGetLastError();

    if (result != cudaSuccess) {
      return Status::kErrorInternal;
    }

    return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};
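
/*
  Usage sketch for the operator defined above (illustrative, not part of the
  original header): the extents M/N/K, leading dimensions lda/ldb/ldc/ldd,
  device pointers ptr_A/ptr_B/ptr_C/ptr_D, scalars alpha/beta, and
  split_k_slices are assumptions chosen for the example.

    using Gemm = cutlass::gemm::device::GemmSplitKParallel<
        float, cutlass::layout::RowMajor,   // A
        float, cutlass::layout::RowMajor,   // B
        float, cutlass::layout::RowMajor>;  // C and D

    Gemm::Arguments args({M, N, K},
                         {ptr_A, lda}, {ptr_B, ldb},
                         {ptr_C, ldc}, {ptr_D, ldd},
                         {alpha, beta},
                         split_k_slices);

    // Split-K partial products require a device workspace sized by the operator.
    void *workspace = nullptr;
    cudaMalloc(&workspace, Gemm::get_workspace_size(args));

    Gemm gemm_op;
    cutlass::Status status = gemm_op(args, workspace);  // initialize() + run()
    cudaFree(workspace);
*/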

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for column-major output exchanges problem size and operand.
template <
  /// Element type for A matrix operand
  typename ElementA_,
  /// Layout type for A matrix operand
  typename LayoutA_,
  /// Element type for B matrix operand
  typename ElementB_,
  /// Layout type for B matrix operand
  typename LayoutB_,
  /// Element type for C and D matrix operands
  typename ElementC_,
  /// Element type for internal accumulation
  typename ElementAccumulator_,
  /// Operator class tag
  typename OperatorClass_,
  /// Tag indicating architecture to tune for
  typename ArchTag_,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape_,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape_,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape_,
  /// Epilogue output operator
  typename EpilogueOutputOp_,
  /// Epilogue conversion operator
  typename ConvertScaledOp_,
  /// Reduction operator
  typename ReductionOp_,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle_,
  /// Number of stages and access granularities of A and B operands
  int Stages, int kAlignmentA, int kAlignmentB,
  /// Operation performed by GEMM
  typename Operator_>
class GemmSplitKParallel<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
                         layout::ColumnMajor, ElementAccumulator_,
                         OperatorClass_, ArchTag_, ThreadblockShape_,
                         WarpShape_, InstructionShape_, EpilogueOutputOp_,
                         ConvertScaledOp_, ReductionOp_, ThreadblockSwizzle_,
                         Stages, kAlignmentA, kAlignmentB, Operator_> {
 public:

  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using ElementC = ElementC_;
  using LayoutC = layout::ColumnMajor;
  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using ConvertScaledOp = ConvertScaledOp_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ReductionOp = ReductionOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  using Operator = Operator_;
  static int const kStages = Stages;

  using UnderlyingOperator = GemmSplitKParallel<
    ElementB,
    typename layout::LayoutTranspose<LayoutB>::type,
    ElementA,
    typename layout::LayoutTranspose<LayoutA>::type,
    ElementC,
    layout::RowMajor,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ConvertScaledOp,
    ReductionOp,
    ThreadblockSwizzle,
    Stages,
    kAlignmentA,
    kAlignmentB,
    Operator
  >;

  using UnderlyingArguments = typename UnderlyingOperator::Arguments;
  using GemmKernel = typename UnderlyingOperator::GemmKernel;
  using ReductionKernel = typename UnderlyingOperator::ReductionKernel;

  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmCoord problem_size;
    TensorRef<ElementA const, LayoutA> ref_A;
    TensorRef<ElementB const, LayoutB> ref_B;
    TensorRef<ElementC const, LayoutC> ref_C;
    TensorRef<ElementC, LayoutC> ref_D;
    typename EpilogueOutputOp::Params epilogue;
    int split_k_slices;
    typename ConvertScaledOp::Params convert;
    typename ReductionOp::Params reduction;

    //
    // Methods
    //

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Arguments() { }

    /// Constructs an Arguments structure
    CUTLASS_HOST_DEVICE
    Arguments(
      GemmCoord problem_size_,
      TensorRef<ElementA const, LayoutA> ref_A_,
      TensorRef<ElementB const, LayoutB> ref_B_,
      TensorRef<ElementC const, LayoutC> ref_C_,
      TensorRef<ElementC, LayoutC> ref_D_,
      typename EpilogueOutputOp::Params epilogue_ =
        typename EpilogueOutputOp::Params(),
      int split_k_slices = 1,
      typename ConvertScaledOp::Params convert_ =
        typename ConvertScaledOp::Params(),
      typename ReductionOp::Params reduction_ =
        typename ReductionOp::Params()
    ):
      problem_size(problem_size_),
      ref_A(ref_A_),
      ref_B(ref_B_),
      ref_C(ref_C_),
      ref_D(ref_D_),
      epilogue(epilogue_),
      split_k_slices(split_k_slices),
      convert(convert_),
      reduction(reduction_) { }
  };

private:

  /// Underlying GEMM operator
  UnderlyingOperator underlying_operator_;

public:

  /// Constructs the GEMM.
  GemmSplitKParallel() { }

  /// Helper to construct a transposed equivalent for the underlying GEMM operator
  static UnderlyingArguments to_underlying_arguments(Arguments const &args) {
    return UnderlyingArguments(
      {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
      {args.ref_B.data(), args.ref_B.stride(0)},
      {args.ref_A.data(), args.ref_A.stride(0)},
      {args.ref_C.data(), args.ref_C.stride(0)},
      {args.ref_D.data(), args.ref_D.stride(0)},
      args.epilogue,
      args.split_k_slices,
      args.convert,
      args.reduction
    );
  }
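
  // Note: since (A B)^T = B^T A^T, swapping the A/B operands and the M/N
  // extents while transposing each layout computes D^T in row-major form,
  // which aliases the requested column-major D without any data movement.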

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {

    return UnderlyingOperator::can_implement(to_underlying_arguments(args));
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {

    return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
  }

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace) {

    return underlying_operator_.initialize(to_underlying_arguments(args), workspace);
  }

  /// Lightweight update given a subset of arguments.
  Status update(Arguments const &args, void *workspace = nullptr) {

    return underlying_operator_.update(to_underlying_arguments(args), workspace);
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {

    return underlying_operator_.run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass