CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
include/cutlass/gemm/device/gemm.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"

#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/gemm.h"

#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
44 namespace cutlass {
45 namespace gemm {
46 namespace device {
47 
49 
113 
116 
119 
122 
125 
128 
131 
134 
137 
140 
143 
146 
149 
152 
155 
159 template <
161  typename ElementA_,
163  typename LayoutA_,
165  typename ElementB_,
167  typename LayoutB_,
169  typename ElementC_,
171  typename LayoutC_,
173  typename ElementAccumulator_ = ElementC_,
175  typename OperatorClass_ = arch::OpClassSimt,
177  typename ArchTag_ = arch::Sm70,
179  typename ThreadblockShape_ = typename DefaultGemmConfiguration<
180  OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
181  ElementAccumulator_>::ThreadblockShape,
183  typename WarpShape_ = typename DefaultGemmConfiguration<
184  OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
185  ElementAccumulator_>::WarpShape,
187  typename InstructionShape_ = typename DefaultGemmConfiguration<
188  OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
189  ElementAccumulator_>::InstructionShape,
191  typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
192  OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
193  ElementAccumulator_>::EpilogueOutputOp,
195  typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle,
197  int Stages =
198  DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
199  ElementC_, ElementAccumulator_>::kStages,
201  int AlignmentA =
202  DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
203  ElementC_, ElementAccumulator_>::kAlignmentA,
205  int AlignmentB =
206  DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
207  ElementC_, ElementAccumulator_>::kAlignmentB,
209  bool SplitKSerial = false,
211  typename Operator_ = typename DefaultGemmConfiguration<
212  OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
213  ElementAccumulator_>::Operator,
215  bool IsBetaZero = false>
216 class Gemm {
217  public:
218 
219  using ElementA = ElementA_;
220  using LayoutA = LayoutA_;
222  using ElementB = ElementB_;
223  using LayoutB = LayoutB_;
225  using ElementC = ElementC_;
226  using LayoutC = LayoutC_;
229  using ElementAccumulator = ElementAccumulator_;
230  using OperatorClass = OperatorClass_;
231  using ArchTag = ArchTag_;
232  using ThreadblockShape = ThreadblockShape_;
233  using WarpShape = WarpShape_;
234  using InstructionShape = InstructionShape_;
235  using EpilogueOutputOp = EpilogueOutputOp_;
236  using ThreadblockSwizzle = ThreadblockSwizzle_;
237  using Operator = Operator_;
238  static int const kStages = Stages;
239  static int const kAlignmentA = AlignmentA;
240  static int const kAlignmentB = AlignmentB;
241  static int const kAlignmentC = EpilogueOutputOp::kCount;
242  static bool const kSplitKSerial = SplitKSerial;
243  static bool const kIsBetaZero = IsBetaZero;
244 
246  using GemmKernel = typename kernel::DefaultGemm<
247  ElementA,
248  LayoutA,
249  kAlignmentA,
250  ElementB,
251  LayoutB,
252  kAlignmentB,
253  ElementC,
254  LayoutC,
257  ArchTag,
259  WarpShape,
263  kStages,
265  Operator,
266  kIsBetaZero
268 
270  struct Arguments {
271 
272  //
273  // Data members
274  //
275 
281  typename EpilogueOutputOp::Params epilogue;
283 
284  //
285  // Methods
286  //
287 
290  Arguments(): problem_size(0, 0, 0), split_k_slices(1) {
291 
292  }
293 
297  GemmCoord problem_size_,
302  typename EpilogueOutputOp::Params epilogue_ =
303  typename EpilogueOutputOp::Params(),
304  int split_k_slices = 1
305  ):
306  problem_size(problem_size_),
307  ref_A(ref_A_),
308  ref_B(ref_B_),
309  ref_C(ref_C_),
310  ref_D(ref_D_),
311  epilogue(epilogue_),
312  split_k_slices(split_k_slices) {
313 
314  }
315  };
316 
317 private:
318 
320  typename GemmKernel::Params params_;
321 
322 public:
323 
325  Gemm() { }
326 
328  static Status can_implement(Arguments const &args) {
329 
330  if (!kSplitKSerial && args.split_k_slices > 1) {
332  }
333 
334  Status status = GemmKernel::can_implement(
335  args.problem_size,
336  args.ref_A.non_const_ref(),
337  args.ref_B.non_const_ref(),
338  args.ref_C.non_const_ref(),
339  args.ref_D
340  );
341 
342  if (status != Status::kSuccess) {
343  return status;
344  }
345 
346  return Status::kSuccess;
347  }
348 
350  static size_t get_workspace_size(Arguments const &args) {
351 
352  if (kSplitKSerial && args.split_k_slices > 1) {
353 
354  // Determine grid shape
355  ThreadblockSwizzle threadblock_swizzle;
356 
357  cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
358  args.problem_size,
359  {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
360  args.split_k_slices);
361 
362  return sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
363  }
364 
365  return 0;
366  }
367 
369  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
370 
371  // Determine grid shape
372  ThreadblockSwizzle threadblock_swizzle;
373 
374  cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
375  args.problem_size,
376  {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
377  args.split_k_slices);
378 
379  if (kSplitKSerial) {
380  if (args.split_k_slices > 1) {
381  if (!workspace) {
383  }
384 
385  size_t bytes = get_workspace_size(args);
386 
387  cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
388 
389  if (result != cudaSuccess) {
390  return Status::kErrorInternal;
391  }
392  }
393  }
394  else {
395 
396  if (args.split_k_slices > 1) {
398  }
399  }
400 
401  // Initialize the Params structure
402  params_ = typename GemmKernel::Params{
403  args.problem_size,
404  grid_shape,
405  args.ref_A.non_const_ref(),
406  args.ref_B.non_const_ref(),
407  args.ref_C.non_const_ref(),
408  args.ref_D,
409  args.epilogue,
410  static_cast<int *>(workspace)
411  };
412 
413  return Status::kSuccess;
414  }
415 
417  Status update(Arguments const &args, void *workspace = nullptr) {
418 
419  if (kSplitKSerial && args.split_k_slices > 1) {
420  if (!workspace) {
422  }
423  }
424 
425  params_.ref_A.reset(args.ref_A.non_const_ref().data());
426  params_.ref_B.reset(args.ref_B.non_const_ref().data());
427  params_.ref_C.reset(args.ref_C.non_const_ref().data());
428  params_.ref_D.reset(args.ref_D.data());
429  params_.semaphore = static_cast<int *>(workspace);
430 
431  return Status::kSuccess;
432  }
433 
435  Status run(cudaStream_t stream = nullptr) {
436 
437  ThreadblockSwizzle threadblock_swizzle;
438 
439  dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
440  dim3 block(GemmKernel::kThreadCount, 1, 1);
441 
442  cudaError_t result;
443 
444  int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
445  if (smem_size >= (48 << 10)) {
446  result = cudaFuncSetAttribute(Kernel<GemmKernel>,
447  cudaFuncAttributeMaxDynamicSharedMemorySize,
448  smem_size);
449 
450  if (result != cudaSuccess) {
451  return Status::kErrorInternal;
452  }
453 
454  result = cudaFuncSetAttribute(
455  Kernel<GemmKernel>,
456  cudaFuncAttributePreferredSharedMemoryCarveout, 100);
457 
458  if (result != cudaSuccess) {
459  return Status::kErrorInternal;
460  }
461  }
462 
463  cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
464 
465  result = cudaGetLastError();
466 
467  return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
468  }
469 
471  Status operator()(cudaStream_t stream = nullptr) {
472  return run(stream);
473  }
474 
477  Arguments const &args,
478  void *workspace = nullptr,
479  cudaStream_t stream = nullptr) {
480 
481  Status status = initialize(args, workspace);
482 
483  if (status == Status::kSuccess) {
484  status = run(stream);
485  }
486 
487  return status;
488  }
489 };
490 
492 
494 template <
496  typename ElementA_,
498  typename LayoutA_,
500  typename ElementB_,
502  typename LayoutB_,
504  typename ElementC_,
506  typename ElementAccumulator_,
508  typename OperatorClass_,
510  typename ArchTag_,
512  typename ThreadblockShape_,
514  typename WarpShape_,
516  typename InstructionShape_,
518  typename EpilogueOutputOp_,
520  typename ThreadblockSwizzle_,
522  int Stages,
524  int AlignmentA,
526  int AlignmentB,
528  bool SplitKSerial,
530  typename Operator_,
532  bool IsBetaZero>
533 class Gemm<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
534  layout::ColumnMajor, // partially specialized on LayoutC
535  ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_,
536  WarpShape_, InstructionShape_, EpilogueOutputOp_,
537  ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial,
538  Operator_, IsBetaZero> {
539  public:
540 
541  using ElementA = ElementA_;
542  using LayoutA = LayoutA_;
544  using ElementB = ElementB_;
545  using LayoutB = LayoutB_;
547  using ElementC = ElementC_;
548  using LayoutC = layout::ColumnMajor;
551  using ElementAccumulator = ElementAccumulator_;
552  using OperatorClass = OperatorClass_;
553  using ArchTag = ArchTag_;
554  using ThreadblockShape = ThreadblockShape_;
555  using WarpShape = WarpShape_;
556  using InstructionShape = InstructionShape_;
557  using EpilogueOutputOp = EpilogueOutputOp_;
558  using ThreadblockSwizzle = ThreadblockSwizzle_;
559  using Operator = Operator_;
560  static int const kStages = Stages;
561  static int const kAlignmentA = AlignmentA;
562  static int const kAlignmentB = AlignmentB;
563  static bool const kSplitKSerial = SplitKSerial;
564  static bool const kIsBetaZero = IsBetaZero;
565 
566  using UnderlyingOperator = Gemm<
567  ElementB,
569  ElementA,
571  ElementC,
575  ArchTag,
577  WarpShape,
581  Stages,
582  kAlignmentB,
583  kAlignmentA,
584  SplitKSerial,
585  Operator,
586  kIsBetaZero
587  >;
588 
589  using UnderlyingArguments = typename UnderlyingOperator::Arguments;
591  static int const kAlignmentC = UnderlyingOperator::kAlignmentC;
592 
594  struct Arguments {
595 
596  //
597  // Data members
598  //
599 
605  typename EpilogueOutputOp::Params epilogue;
607 
608  //
609  // Methods
610  //
611 
614  Arguments() { }
615 
619  GemmCoord problem_size_,
624  typename EpilogueOutputOp::Params epilogue_ =
625  typename EpilogueOutputOp::Params(),
626  int split_k_slices = 1
627  ):
628  problem_size(problem_size_),
629  ref_A(ref_A_),
630  ref_B(ref_B_),
631  ref_C(ref_C_),
632  ref_D(ref_D_),
633  epilogue(epilogue_),
634  split_k_slices(split_k_slices) { }
635  };
636 
637 private:
638 
639  UnderlyingOperator underlying_operator_;
640 
641 public:
642 
644  Gemm() { }
645 
648  return UnderlyingArguments(
649  {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
650  {args.ref_B.data(), args.ref_B.stride(0)},
651  {args.ref_A.data(), args.ref_A.stride(0)},
652  {args.ref_C.data(), args.ref_C.stride(0)},
653  {args.ref_D.data(), args.ref_D.stride(0)},
654  args.epilogue,
655  args.split_k_slices
656  );
657  }
658 
660  static Status can_implement(Arguments const &args) {
661 
662  return UnderlyingOperator::can_implement(to_underlying_arguments(args));
663  }
664 
666  static size_t get_workspace_size(Arguments const &args) {
667 
668  return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
669  }
670 
672  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
673 
674  return underlying_operator_.initialize(to_underlying_arguments(args), workspace);
675  }
676 
678  Status update(Arguments const &args, void *workspace = nullptr) {
679 
680  return underlying_operator_.update(to_underlying_arguments(args), workspace);
681  }
682 
684  Status run(cudaStream_t stream = nullptr) {
685 
686  return underlying_operator_.run(stream);
687  }
688 
690  Status operator()(cudaStream_t stream = nullptr) {
691  return run(stream);
692  }
693 
696  Arguments const &args,
697  void *workspace = nullptr,
698  cudaStream_t stream = nullptr) {
699 
700  Status status = initialize(args, workspace);
701 
702  if (status == Status::kSuccess) {
703  status = run(stream);
704  }
705 
706  return status;
707  }
708 };
709 
711 
712 } // namespace device
713 } // namespace gemm
714 } // namespace cutlass
715 
Definition: default_gemm.h:116
static int const kStages
Definition: include/cutlass/gemm/device/gemm.h:238
GemmCoord problem_size
Definition: include/cutlass/gemm/device/gemm.h:276
Definition: aligned_buffer.h:35
Specified problem size is not supported by operator.
static size_t get_workspace_size(Arguments const &args)
Gets the workspace size.
Definition: include/cutlass/gemm/device/gemm.h:350
static Status can_implement(Arguments const &args)
Determines whether the GEMM can execute the given problem.
Definition: include/cutlass/gemm/device/gemm.h:328
CUTLASS_HOST_DEVICE Arguments()
Default ctor.
Definition: include/cutlass/gemm/device/gemm.h:290
CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, TensorRef< ElementB const, LayoutB > ref_B_, TensorRef< ElementC const, LayoutC > ref_C_, TensorRef< ElementC, LayoutC > ref_D_, typename EpilogueOutputOp::Params epilogue_=typename EpilogueOutputOp::Params(), int split_k_slices=1)
Constructs an Arguments structure.
Definition: include/cutlass/gemm/device/gemm.h:296
Definition: include/cutlass/gemm/gemm.h:94
Definition: include/cutlass/gemm/device/gemm.h:216
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
static bool const kSplitKSerial
Definition: include/cutlass/gemm/device/gemm.h:242
EpilogueOutputOp::Params epilogue
Definition: include/cutlass/gemm/device/gemm.h:281
TensorRef< ElementA const, LayoutA > ref_A
Definition: include/cutlass/gemm/device/gemm.h:277
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with the appropr...
Status update(Arguments const &args, void *workspace=nullptr)
Lightweight update given a subset of arguments.
Definition: include/cutlass/gemm/device/gemm.h:417
int split_k_slices
Definition: include/cutlass/gemm/device/gemm.h:282
Status operator()(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm.h:471
typename kernel::DefaultGemm< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, kStages, kSplitKSerial, Operator, kIsBetaZero >::GemmKernel GemmKernel
Define the kernel.
Definition: include/cutlass/gemm/device/gemm.h:267
static int const kAlignmentB
Definition: include/cutlass/gemm/device/gemm.h:240
Defines transposes of matrix layouts.
Definition: layout/matrix.h:921
static int const kAlignmentA
Definition: include/cutlass/gemm/device/gemm.h:239
TensorRef< ElementC const, LayoutC > ref_C
Definition: include/cutlass/gemm/device/gemm.h:279
TensorRef< ElementB const, LayoutB > ref_B
Definition: include/cutlass/gemm/device/gemm.h:278
An error within CUTLASS occurred.
Template for generic CUTLASS kernel.
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
Status operator()(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm.h:476
Gemm()
Constructs the GEMM.
Definition: include/cutlass/gemm/device/gemm.h:325
CUTLASS_HOST_DEVICE Arguments(GemmCoord problem_size_, TensorRef< ElementA const, LayoutA > ref_A_, TensorRef< ElementB const, LayoutB > ref_B_, TensorRef< ElementC const, LayoutC > ref_C_, TensorRef< ElementC, LayoutC > ref_D_, typename EpilogueOutputOp::Params epilogue_=typename EpilogueOutputOp::Params(), int split_k_slices=1)
Constructs an Arguments structure.
Definition: include/cutlass/gemm/device/gemm.h:618
LayoutC_ LayoutC
Definition: include/cutlass/gemm/device/gemm.h:226
Argument structure.
Definition: include/cutlass/gemm/device/gemm.h:270
Definitions for GEMM structures.
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
Status run(cudaStream_t stream=nullptr)
Runs the kernel using initialized state.
Definition: include/cutlass/gemm/device/gemm.h:435
Status initialize(Arguments const &args, void *workspace=nullptr, cudaStream_t stream=nullptr)
Initializes GEMM state from arguments.
Definition: include/cutlass/gemm/device/gemm.h:369
The given workspace is null when it is required to be non-null.
Operation was successful.
static int const kAlignmentC
Definition: include/cutlass/gemm/device/gemm.h:241
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
Implements several possible threadblock-swizzling functions mapping blockIdx to GEMM problems...
Defines tags for architecture-specific configurations.
TensorRef< ElementC, LayoutC > ref_D
Definition: include/cutlass/gemm/device/gemm.h:280
static bool const kIsBetaZero
Definition: include/cutlass/gemm/device/gemm.h:243
Basic include for CUTLASS.
Status
Status code returned by CUTLASS operations.
Definition: cutlass.h:39
Template for a pipelined GEMM kernel. Does not compute batching or support split-K.