CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
device/gemm_batched.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"

#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/gemm_batched.h"

#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////
template <
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Element type for C and D matrix operands
    typename ElementC_,
    /// Layout type for C and D matrix operands
    typename LayoutC_,
    /// Element type for internal accumulation
    typename ElementAccumulator_ = ElementC_,
    /// Operator class tag
    typename OperatorClass_ = arch::OpClassSimt,
    /// Tag indicating architecture to tune for
    typename ArchTag_ = arch::Sm70,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::EpilogueOutputOp,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle_ = threadblock::GemmBatchedIdentityThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kStages,
    /// Access granularity of A matrix in units of elements
    int AlignmentA =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kAlignmentA,
    /// Access granularity of B matrix in units of elements
    int AlignmentB =
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::kAlignmentB,
    /// Operation performed by GEMM
    typename Operator_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::Operator
>
class GemmBatched {
 public:

  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  static int const kStages = Stages;
  static int const kAlignmentA = AlignmentA;
  static int const kAlignmentB = AlignmentB;
  static int const kAlignmentC = EpilogueOutputOp::kCount;
  using Operator = Operator_;
  /// Define the kernel
  using DefaultGemmKernel = typename kernel::DefaultGemm<
    ElementA,
    LayoutA,
    kAlignmentA,
    ElementB,
    LayoutB,
    kAlignmentB,
    ElementC,
    LayoutC,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    kStages,
    false,
    Operator,
    false
  >::GemmKernel;

  using GemmKernel = kernel::GemmBatched<
    typename DefaultGemmKernel::Mma,
    typename DefaultGemmKernel::Epilogue,
    ThreadblockSwizzle
  >;
  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmCoord problem_size;
    TensorRef<ElementA const, LayoutA> ref_A;
    int64_t stride_A;
    TensorRef<ElementB const, LayoutB> ref_B;
    int64_t stride_B;
    TensorRef<ElementC const, LayoutC> ref_C;
    int64_t stride_C;
    TensorRef<ElementC, LayoutC> ref_D;
    int64_t stride_D;
    typename EpilogueOutputOp::Params epilogue;
    int batch_count;

    //
    // Methods
    //

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Arguments() { }

    /// Constructs an Arguments structure
    CUTLASS_HOST_DEVICE
    Arguments(
      GemmCoord problem_size_,
      TensorRef<ElementA const, LayoutA> ref_A_,
      int64_t stride_A_,
      TensorRef<ElementB const, LayoutB> ref_B_,
      int64_t stride_B_,
      TensorRef<ElementC const, LayoutC> ref_C_,
      int64_t stride_C_,
      TensorRef<ElementC, LayoutC> ref_D_,
      int64_t stride_D_,
      typename EpilogueOutputOp::Params epilogue_,
      int batch_count_
    ):
      problem_size(problem_size_),
      ref_A(ref_A_),
      stride_A(stride_A_),
      ref_B(ref_B_),
      stride_B(stride_B_),
      ref_C(ref_C_),
      stride_C(stride_C_),
      ref_D(ref_D_),
      stride_D(stride_D_),
      epilogue(epilogue_),
      batch_count(batch_count_) { }
  };
private:

  /// Kernel parameters object
  typename GemmKernel::Params params_;

public:
  /// Constructs the GEMM.
  GemmBatched() { }

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {

    if (!TensorRef_aligned(args.ref_A, kAlignmentA) || (args.stride_A % kAlignmentA)) {
      return Status::kErrorMisalignedOperand;
    }

    if (!TensorRef_aligned(args.ref_B, kAlignmentB) || (args.stride_B % kAlignmentB)) {
      return Status::kErrorMisalignedOperand;
    }

    if (!TensorRef_aligned(args.ref_C, kAlignmentC) || (args.stride_C % kAlignmentC)) {
      return Status::kErrorMisalignedOperand;
    }

    if (!TensorRef_aligned(args.ref_D, kAlignmentC) || (args.stride_D % kAlignmentC)) {
      return Status::kErrorMisalignedOperand;
    }

    if ((args.problem_size.m() % kAlignmentA) || (args.problem_size.k() % kAlignmentA) ||
        (args.problem_size.n() % kAlignmentB) || (args.problem_size.k() % kAlignmentB) ||
        (args.problem_size.m() % kAlignmentC) || (args.problem_size.n() % kAlignmentC)) {

      return Status::kErrorMisalignedOperand;
    }

    return Status::kSuccess;
  }
  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {
    return 0;
  }

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
      args.problem_size,
      args.batch_count,
      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK});

    // Initialize the Params structure
    params_ = typename GemmKernel::Params{
      args.problem_size,
      grid_shape,
      args.ref_A.non_const_ref(),
      args.stride_A,
      args.ref_B.non_const_ref(),
      args.stride_B,
      args.ref_C.non_const_ref(),
      args.stride_C,
      args.ref_D,
      args.stride_D,
      args.epilogue,
      args.batch_count
    };

    return Status::kSuccess;
  }

  /// Lightweight update given a subset of arguments.
  Status update(Arguments const &args, void *workspace = nullptr) {

    params_.ref_A.reset(args.ref_A.non_const_ref().data());
    params_.ref_B.reset(args.ref_B.non_const_ref().data());
    params_.ref_C.reset(args.ref_C.non_const_ref().data());
    params_.ref_D.reset(args.ref_D.data());

    return Status::kSuccess;
  }
  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {

    ThreadblockSwizzle threadblock_swizzle;

    dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
    dim3 block(GemmKernel::kThreadCount, 1, 1);

    cudaError_t result;

    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
    if (smem_size >= (48 << 10)) {
      result = cudaFuncSetAttribute(Kernel<GemmKernel>,
                                    cudaFuncAttributeMaxDynamicSharedMemorySize,
                                    smem_size);

      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }

      result = cudaFuncSetAttribute(
          Kernel<GemmKernel>,
          cudaFuncAttributePreferredSharedMemoryCarveout, 100);

      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }
    }

    cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);

    result = cudaGetLastError();

    return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }
  /// Runs the kernel using initialized state.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for column-major output exchanges problem size and operands.
template <
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Element type for C and D matrix operands
    typename ElementC_,
    /// Element type for internal accumulation
    typename ElementAccumulator_,
    /// Operator class tag
    typename OperatorClass_,
    /// Tag indicating architecture to tune for
    typename ArchTag_,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape_,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape_,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape_,
    /// Epilogue output operator
    typename EpilogueOutputOp_,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle_,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    /// Access granularity of A matrix in units of elements
    int AlignmentA,
    /// Access granularity of B matrix in units of elements
    int AlignmentB,
    /// Operation performed by GEMM
    typename Operator_
>
class GemmBatched<
  ElementA_,
  LayoutA_,
  ElementB_,
  LayoutB_,
  ElementC_,
  layout::ColumnMajor,
  ElementAccumulator_,
  OperatorClass_,
  ArchTag_,
  ThreadblockShape_,
  WarpShape_,
  InstructionShape_,
  EpilogueOutputOp_,
  ThreadblockSwizzle_,
  Stages,
  AlignmentA,
  AlignmentB,
  Operator_
> {
public:

  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using ElementC = ElementC_;
  using LayoutC = layout::ColumnMajor;
  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  static int const kStages = Stages;

  static int const kAlignmentA = AlignmentA;
  static int const kAlignmentB = AlignmentB;
  static int const kAlignmentC = EpilogueOutputOp::kCount;
  static bool const kSplitKSerial = false;

  // Define the underlying GEMM operator on the transposed problem
  using UnderlyingOperator = GemmBatched<
    ElementB,
    typename layout::LayoutTranspose<LayoutB>::type,
    ElementA,
    typename layout::LayoutTranspose<LayoutA>::type,
    ElementC,
    layout::RowMajor,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    Stages,
    kAlignmentB,
    kAlignmentA
  >;

  using UnderlyingArguments = typename UnderlyingOperator::Arguments;
  using GemmKernel = typename UnderlyingOperator::GemmKernel;
  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmCoord problem_size;
    TensorRef<ElementA const, LayoutA> ref_A;
    int64_t stride_A;
    TensorRef<ElementB const, LayoutB> ref_B;
    int64_t stride_B;
    TensorRef<ElementC const, LayoutC> ref_C;
    int64_t stride_C;
    TensorRef<ElementC, LayoutC> ref_D;
    int64_t stride_D;
    typename EpilogueOutputOp::Params epilogue;
    int batch_count;

    //
    // Methods
    //

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Arguments() { }

    /// Constructs an Arguments structure
    CUTLASS_HOST_DEVICE
    Arguments(
      GemmCoord problem_size_,
      TensorRef<ElementA const, LayoutA> ref_A_,
      int64_t stride_A_,
      TensorRef<ElementB const, LayoutB> ref_B_,
      int64_t stride_B_,
      TensorRef<ElementC const, LayoutC> ref_C_,
      int64_t stride_C_,
      TensorRef<ElementC, LayoutC> ref_D_,
      int64_t stride_D_,
      typename EpilogueOutputOp::Params epilogue_,
      int batch_count_
    ):
      problem_size(problem_size_),
      ref_A(ref_A_),
      stride_A(stride_A_),
      ref_B(ref_B_),
      stride_B(stride_B_),
      ref_C(ref_C_),
      stride_C(stride_C_),
      ref_D(ref_D_),
      stride_D(stride_D_),
      epilogue(epilogue_),
      batch_count(batch_count_) { }
  };
private:

  UnderlyingOperator underlying_operator_;

public:
  /// Constructs the GEMM.
  GemmBatched() { }

  /// Helper to construct a transposed equivalent for the underlying GEMM operator
  static UnderlyingArguments to_underlying_arguments(Arguments const &args) {
    return UnderlyingArguments(
      {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
      {args.ref_B.data(), args.ref_B.stride(0)},
      args.stride_B,
      {args.ref_A.data(), args.ref_A.stride(0)},
      args.stride_A,
      {args.ref_C.data(), args.ref_C.stride(0)},
      args.stride_C,
      {args.ref_D.data(), args.ref_D.stride(0)},
      args.stride_D,
      args.epilogue,
      args.batch_count
    );
  }
  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {

    return UnderlyingOperator::can_implement(to_underlying_arguments(args));
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {

    return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
  }

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

    return underlying_operator_.initialize(to_underlying_arguments(args), workspace);
  }

  /// Lightweight update given a subset of arguments.
  Status update(Arguments const &args, void *workspace = nullptr) {

    return underlying_operator_.update(to_underlying_arguments(args), workspace);
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {

    return underlying_operator_.run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }
  /// Runs the kernel using initialized state.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }

};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass
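
The header above only defines the device-level GemmBatched template. As a minimal usage sketch (modeled on the strided batched SGEMM pattern in CUTLASS's batched GEMM example; the function name run_batched_sgemm and all pointer, leading-dimension, and batch-stride arguments are illustrative, not part of this header), a host-side launch might look like the following. Because LayoutC here is column-major, the call is served by the partial specialization, which transposes the problem and exchanges the A and B operands internally.

#include <cuda_runtime.h>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_batched.h"

// Strided batched SGEMM: D = alpha * A * B + beta * C for each of batch_count problems.
// Each operand i lives at base_ptr + i * batch_stride.
cudaError_t run_batched_sgemm(
    int m, int n, int k, int batch_count,
    float alpha,
    float const *A, int lda, int64_t batch_stride_A,
    float const *B, int ldb, int64_t batch_stride_B,
    float beta,
    float *C, int ldc, int64_t batch_stride_C) {

  // Instantiate the device-level operator with column-major operands; all remaining
  // template parameters fall back to DefaultGemmConfiguration.
  using BatchedGemm = cutlass::gemm::device::GemmBatched<
      float, cutlass::layout::ColumnMajor,
      float, cutlass::layout::ColumnMajor,
      float, cutlass::layout::ColumnMajor>;

  BatchedGemm gemm_op;

  // Arguments mirror the struct defined above: problem size, a TensorRef plus a batch
  // stride per operand, the epilogue scalars, and the batch count. D aliases C here,
  // so the result is written in place.
  cutlass::Status status = gemm_op({
      {m, n, k},
      {A, lda}, batch_stride_A,
      {B, ldb}, batch_stride_B,
      {C, ldc}, batch_stride_C,
      {C, ldc}, batch_stride_C,
      {alpha, beta},
      batch_count});

  return (status == cutlass::Status::kSuccess) ? cudaSuccess : cudaErrorUnknown;
}

Calling can_implement() with the same Arguments before launch is the usual way to verify that the chosen alignments and strides satisfy the checks shown in the class above.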