CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
include/cutlass/gemm/device/gemm_complex.h
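This header defines cutlass::gemm::device::GemmComplex, a device-wide GEMM operator for complex-valued operands, together with a partial specialization that handles column-major output by mapping the problem onto its row-major transpose.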
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"

#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/gemm.h"

#include "cutlass/gemm/kernel/default_gemm_complex.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace device {
////////////////////////////////////////////////////////////////////////////////

/*!
  GemmComplex is a stateful, reusable device-wide GEMM operator for complex-valued
  operands. Either operand may be conjugated elementwise via its ComplexTransform
  template argument, and serial split-K reduction through a workspace of per-tile
  semaphores is supported when SplitKSerial is enabled. Tile shapes, the epilogue,
  and the number of mainloop stages default to values chosen by
  DefaultGemmConfiguration for the selected operator class and architecture.
*/
template <
  /// Element type for A matrix operand
  typename ElementA_,
  /// Layout type for A matrix operand
  typename LayoutA_,
  /// Element type for B matrix operand
  typename ElementB_,
  /// Layout type for B matrix operand
  typename LayoutB_,
  /// Element type for C and D matrix operands
  typename ElementC_,
  /// Layout type for C and D matrix operands
  typename LayoutC_,
  /// Element type for internal accumulation
  typename ElementAccumulator_ = ElementC_,
  /// Operator class tag
  typename OperatorClass_ = arch::OpClassSimt,
  /// Tag indicating architecture to tune for
  typename ArchTag_ = arch::Sm70,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::ThreadblockShape,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::WarpShape,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::InstructionShape,
  /// Epilogue output operator
  typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
      OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
      ElementAccumulator_>::EpilogueOutputOp,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle,
  /// Number of stages used in the pipelined mainloop
  int Stages =
      DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                               ElementC_, ElementAccumulator_>::kStages,
  /// Complex elementwise transformation on A operand
  ComplexTransform TransformA = ComplexTransform::kNone,
  /// Complex elementwise transformation on B operand
  ComplexTransform TransformB = ComplexTransform::kNone,
  /// If true, kernel supports split-K with serial reduction
  bool SplitKSerial = false
>
class GemmComplex {
 public:

  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  static int const kStages = Stages;
  static ComplexTransform const kTransformA = TransformA;
  static ComplexTransform const kTransformB = TransformB;
  static bool const kSplitKSerial = SplitKSerial;

  /// Define the kernel
  using GemmKernel = typename kernel::DefaultGemmComplex<
    ElementA,
    LayoutA,
    ElementB,
    LayoutB,
    ElementC,
    LayoutC,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    kStages,
    kTransformA,
    kTransformB,
    kSplitKSerial
  >::GemmKernel;

  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmCoord problem_size;
    TensorRef<ElementA const, LayoutA> ref_A;
    TensorRef<ElementB const, LayoutB> ref_B;
    TensorRef<ElementC const, LayoutC> ref_C;
    TensorRef<ElementC, LayoutC> ref_D;
    typename EpilogueOutputOp::Params epilogue;
    int split_k_slices;

    //
    // Methods
    //

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Arguments(): problem_size(0, 0, 0), split_k_slices(1) { }

    /// Constructs an Arguments structure
    CUTLASS_HOST_DEVICE
    Arguments(
      GemmCoord problem_size_,
      TensorRef<ElementA const, LayoutA> ref_A_,
      TensorRef<ElementB const, LayoutB> ref_B_,
      TensorRef<ElementC const, LayoutC> ref_C_,
      TensorRef<ElementC, LayoutC> ref_D_,
      typename EpilogueOutputOp::Params epilogue_ =
        typename EpilogueOutputOp::Params(),
      int split_k_slices = 1
    ):
      problem_size(problem_size_),
      ref_A(ref_A_),
      ref_B(ref_B_),
      ref_C(ref_C_),
      ref_D(ref_D_),
      epilogue(epilogue_),
      split_k_slices(split_k_slices) { }
  };

private:

  /// Kernel parameters object
  typename GemmKernel::Params params_;

public:

  /// Constructs the GEMM.
  GemmComplex() { }

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {

    if (!kSplitKSerial && args.split_k_slices > 1) {
      return Status::kErrorInvalidProblem;
    }

    return Status::kSuccess;
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {

    if (kSplitKSerial && args.split_k_slices > 1) {

      // Determine grid shape
      ThreadblockSwizzle threadblock_swizzle;

      cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
        args.problem_size,
        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
        args.split_k_slices);

      // One semaphore (int) per output tile coordinates the serial reduction
      return sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
    }

    return 0;
  }

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
      args.problem_size,
      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
      args.split_k_slices);

    if (kSplitKSerial) {
      if (args.split_k_slices > 1) {
        if (!workspace) {
          return Status::kErrorWorkspaceNull;
        }

        size_t bytes = get_workspace_size(args);

        // Semaphores must start at zero before the first split-K pass
        cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);

        if (result != cudaSuccess) {
          return Status::kErrorInternal;
        }
      }
    }
    else {

      if (args.split_k_slices > 1) {
        return Status::kErrorInvalidProblem;
      }
    }

    // Initialize the Params structure
    params_ = typename GemmKernel::Params{
      args.problem_size,
      grid_shape,
      args.ref_A.non_const_ref(),
      args.ref_B.non_const_ref(),
      args.ref_C.non_const_ref(),
      args.ref_D,
      args.epilogue,
      static_cast<int *>(workspace)
    };

    return Status::kSuccess;
  }

  /// Lightweight update given a subset of arguments
  Status update(Arguments const &args, void *workspace = nullptr) {

    if (kSplitKSerial && args.split_k_slices > 1) {
      if (!workspace) {
        return Status::kErrorWorkspaceNull;
      }
    }

    params_.ref_A.reset(args.ref_A.non_const_ref().data());
    params_.ref_B.reset(args.ref_B.non_const_ref().data());
    params_.ref_C.reset(args.ref_C.non_const_ref().data());
    params_.ref_D.reset(args.ref_D.data());
    params_.semaphore = static_cast<int *>(workspace);

    return Status::kSuccess;
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {

    ThreadblockSwizzle threadblock_swizzle;

    dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
    dim3 block(GemmKernel::kThreadCount, 1, 1);

    cudaError_t result;

    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
    if (smem_size >= (48 << 10)) {

      // Opt in to shared memory capacities above the default 48 KB limit
      result = cudaFuncSetAttribute(Kernel<GemmKernel>,
                                    cudaFuncAttributeMaxDynamicSharedMemorySize,
                                    smem_size);

      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }

      result = cudaFuncSetAttribute(
          Kernel<GemmKernel>,
          cudaFuncAttributePreferredSharedMemoryCarveout, 100);

      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }
    }

    cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);

    result = cudaGetLastError();

    return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Initializes the GEMM state and runs the kernel.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for column-major output exchanges problem size and operand.
template <
  /// Element type for A matrix operand
  typename ElementA_,
  /// Layout type for A matrix operand
  typename LayoutA_,
  /// Element type for B matrix operand
  typename ElementB_,
  /// Layout type for B matrix operand
  typename LayoutB_,
  /// Element type for C and D matrix operands
  typename ElementC_,
  /// Element type for internal accumulation
  typename ElementAccumulator_,
  /// Operator class tag
  typename OperatorClass_,
  /// Tag indicating architecture to tune for
  typename ArchTag_,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape_,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape_,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape_,
  /// Epilogue output operator
  typename EpilogueOutputOp_,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle_,
  /// Number of stages used in the pipelined mainloop
  int Stages,
  /// Complex elementwise transformation on A operand
  ComplexTransform TransformA,
  /// Complex elementwise transformation on B operand
  ComplexTransform TransformB,
  /// If true, kernel supports split-K with serial reduction
  bool SplitKSerial
>
class GemmComplex<
  ElementA_,
  LayoutA_,
  ElementB_,
  LayoutB_,
  ElementC_,
  layout::ColumnMajor,  // partially specialized on LayoutC
  ElementAccumulator_,
  OperatorClass_,
  ArchTag_,
  ThreadblockShape_,
  WarpShape_,
  InstructionShape_,
  EpilogueOutputOp_,
  ThreadblockSwizzle_,
  Stages,
  TransformA,
  TransformB,
  SplitKSerial
> {
public:

  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using ElementC = ElementC_;
  using LayoutC = layout::ColumnMajor;
  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  static int const kStages = Stages;
  static bool const kSplitKSerial = SplitKSerial;

  /// The underlying operator computes the transposed problem with row-major output
  using UnderlyingOperator = GemmComplex<
    ElementB,
    typename layout::LayoutTranspose<LayoutB>::type,
    ElementA,
    typename layout::LayoutTranspose<LayoutA>::type,
    ElementC,
    layout::RowMajor,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ThreadblockSwizzle,
    Stages,
    TransformA,
    TransformB,
    SplitKSerial
  >;

  using UnderlyingArguments = typename UnderlyingOperator::Arguments;

  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmCoord problem_size;
    TensorRef<ElementA const, LayoutA> ref_A;
    TensorRef<ElementB const, LayoutB> ref_B;
    TensorRef<ElementC const, LayoutC> ref_C;
    TensorRef<ElementC, LayoutC> ref_D;
    typename EpilogueOutputOp::Params epilogue;
    int split_k_slices;

    //
    // Methods
    //

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Arguments() { }

    /// Constructs an Arguments structure
    CUTLASS_HOST_DEVICE
    Arguments(
      GemmCoord problem_size_,
      TensorRef<ElementA const, LayoutA> ref_A_,
      TensorRef<ElementB const, LayoutB> ref_B_,
      TensorRef<ElementC const, LayoutC> ref_C_,
      TensorRef<ElementC, LayoutC> ref_D_,
      typename EpilogueOutputOp::Params epilogue_ =
        typename EpilogueOutputOp::Params(),
      int split_k_slices = 1
    ):
      problem_size(problem_size_),
      ref_A(ref_A_),
      ref_B(ref_B_),
      ref_C(ref_C_),
      ref_D(ref_D_),
      epilogue(epilogue_),
      split_k_slices(split_k_slices) { }
  };

private:

  UnderlyingOperator underlying_operator_;

public:

  /// Constructs the GEMM.
  GemmComplex() { }

  /// Helper to construct a transposed equivalent for the underlying GEMM operator
  static UnderlyingArguments to_underlying_arguments(Arguments const &args) {
    return UnderlyingArguments(
      {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
      {args.ref_B.data(), args.ref_B.stride(0)},
      {args.ref_A.data(), args.ref_A.stride(0)},
      {args.ref_C.data(), args.ref_C.stride(0)},
      {args.ref_D.data(), args.ref_D.stride(0)},
      args.epilogue,
      args.split_k_slices
    );
  }

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {

    return UnderlyingOperator::can_implement(to_underlying_arguments(args));
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {

    return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
  }

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {

    return underlying_operator_.initialize(to_underlying_arguments(args), workspace);
  }

  /// Lightweight update given a subset of arguments
  Status update(Arguments const &args, void *workspace = nullptr) {

    return underlying_operator_.update(to_underlying_arguments(args), workspace);
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {

    return underlying_operator_.run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Initializes the GEMM state and runs the kernel.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};

////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass
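
A minimal usage sketch follows, assuming the default template arguments (SIMT operator class, defaults from DefaultGemmConfiguration, and ComplexTransform::kNone on both operands, so no conjugation is applied). The function name run_complex_gemm, the device pointers, and the leading dimensions are hypothetical placeholders, not part of this header; the epilogue computes D = alpha * A * B + beta * C.

#include "cutlass/complex.h"
#include "cutlass/gemm/device/gemm_complex.h"

cutlass::Status run_complex_gemm(
    int M, int N, int K,
    cutlass::complex<float> const *d_A, int lda,  // column-major A, M-by-K
    cutlass::complex<float> const *d_B, int ldb,  // column-major B, K-by-N
    cutlass::complex<float> const *d_C, int ldc,  // column-major C, M-by-N
    cutlass::complex<float> *d_D, int ldd) {      // column-major D, M-by-N

  // Instantiate the device-level operator; unlisted template arguments take
  // their defaults, including the column-major partial specialization for D.
  using Gemm = cutlass::gemm::device::GemmComplex<
      cutlass::complex<float>, cutlass::layout::ColumnMajor,   // A
      cutlass::complex<float>, cutlass::layout::ColumnMajor,   // B
      cutlass::complex<float>, cutlass::layout::ColumnMajor>;  // C and D

  // alpha = 1, beta = 0 computes D = A * B
  typename Gemm::EpilogueOutputOp::Params epilogue(
      cutlass::complex<float>(1.0f), cutlass::complex<float>(0.0f));

  typename Gemm::Arguments args(
      {M, N, K},   // GemmCoord problem size
      {d_A, lda},  // TensorRef to A
      {d_B, ldb},  // TensorRef to B
      {d_C, ldc},  // TensorRef to C (accumulator source)
      {d_D, ldd},  // TensorRef to D (destination)
      epilogue);   // split_k_slices defaults to 1, so no workspace is needed

  // Reject unsupported problems before touching the device
  cutlass::Status status = Gemm::can_implement(args);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  // Initializes state and launches the kernel on the default stream
  Gemm gemm_op;
  return gemm_op(args, /*workspace=*/nullptr);
}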