CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
mma_tensor_op_tile_iterator_wmma.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
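/*!
  \file
  \brief Defines warp-level tile iterators that wrap the nvcuda::wmma API for loading and
         storing multiplicand and accumulator tiles used by warp-level matrix multiply operations.
*/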
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/arch/wmma.h"

#if defined(CUTLASS_ARCH_WMMA_ENABLED)

#include "cutlass/wmma_array.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/matrix_shape.h"

#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/gemm.h"

#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/pitch_linear.h"

#include "cutlass/platform/platform.h"
#include "cutlass/fast_math.h"

namespace cutlass {
namespace gemm {
namespace warp {

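// Iterator over warp-level multiplicand tiles backed by the nvcuda::wmma API.
// Template parameters (names taken from the declaration below): Shape_ is the warp-level
// tile shape (concept: MatrixShape), Operand_ selects the A or B multiplicand, Element_ and
// Layout_ describe the data in memory, OpDelta_ is the interval between accesses (only 1 is
// supported by the specializations below), Threads is the number of participating threads
// (32, i.e. one warp), and Policy_ supplies the underlying WMMA operator via Policy::Operator.
// Only the Operand::kA and Operand::kB specializations are defined in this file.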
template <
  typename Shape_,
  Operand Operand_,
  typename Element_,
  typename Layout_,
  int OpDelta_,
  int Threads,
  typename Policy_>
class MmaTensorOpWmmaMultiplicandTileIterator;

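// Specialization for the A operand: the warp tile is covered by Shape::kRow / WmmaShape::kRow
// WMMA FragmentA fragments, and operator++ advances the iterator by one WMMA tile along the
// K dimension. This specialization assumes 32 participating threads (a full warp).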
template <
  typename Shape_,
  typename Element_,
  typename Layout_,
  int OpDelta_,
  typename Policy_>
class MmaTensorOpWmmaMultiplicandTileIterator<
  Shape_, Operand::kA, Element_, Layout_,
  OpDelta_, 32, Policy_> {
public:

  using Shape = Shape_;

  static Operand const kOperand = Operand::kA;

  using Element = Element_;

  using Layout = Layout_;

  static int const kOpDelta = OpDelta_;

  using Policy = Policy_;

  //
  // Derived quantities
  //

  using TensorRef = TensorRef<Element, Layout>;

  using Index = typename TensorRef::Index;

  using LongIndex = typename TensorRef::LongIndex;

  using TensorCoord = typename TensorRef::TensorCoord;

  using WmmaShape = MatrixShape<
    Policy::Operator::Shape::kM,
    Policy::Operator::Shape::kK
  >;

  using WmmaDataType = typename cutlass::arch::CutlassToWmmaDataType<Element>::Type;

  using Iterations = MatrixShape<
    Shape::kRow / WmmaShape::kRow,
    1
  >;

  using Fragment = WmmaFragmentArray<typename Policy::Operator::FragmentA, Iterations::kCount>;

  static_assert(kOperand == Operand::kA,
    "MmaTensorOpWmmaMultiplicandTileIterator may only be instantiated for A operands to warp-level Mma.");

  static_assert(
    platform::is_same<cutlass::layout::RowMajor, Layout>::value ||
    platform::is_same<cutlass::layout::ColumnMajor, Layout>::value,
    "Supported list of memory layouts for WMMA are: RowMajor, ColumnMajor");

  static_assert(kOpDelta == 1,
    "Alternative arrangements not supported at present.");

private:

  char const *pointer_;

  Index byte_offset_;

  Index stride_;

  Layout layout_;

public:

  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator() { }

  CUTLASS_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator(
    TensorRef const &ref,
    int lane_id
  ): pointer_(reinterpret_cast<char const*>(ref.data())), byte_offset_(0), stride_(ref.stride(0)), layout_(ref.stride(0)) {
  }

  CUTLASS_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator &add_pointer_offset(LongIndex offset) {
    byte_offset_ += (offset * sizeof_bits<Element>::value) / 8;
    return *this;
  }

  CUTLASS_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) {

    Index elements_offset = layout_({tile_offset.row() * Shape::kRow, tile_offset.column() * WmmaShape::kColumn});

    byte_offset_ += (elements_offset * sizeof_bits<Element>::value) / 8;

    return *this;
  }

  CUTLASS_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator & operator++() {

    Index elements_offset = layout_({0, WmmaShape::kColumn});

    byte_offset_ += (elements_offset * sizeof_bits<Element>::value) / 8;

    return *this;
  }

  CUTLASS_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator & operator--() {

    Index elements_offset = layout_({0, WmmaShape::kColumn});

    byte_offset_ -= (elements_offset * sizeof_bits<Element>::value) / 8;

    return *this;
  }

  CUTLASS_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) {
    add_tile_offset(tile_offset);
    return *this;
  }

  CUTLASS_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) {
    add_tile_offset(-tile_offset);
    return *this;
  }

  CUTLASS_HOST_DEVICE
  void load_with_byte_offset(Fragment &frag, Index byte_offset) const {

    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < Iterations::kColumn; ++k) {
      CUTLASS_PRAGMA_UNROLL
      for (int m = 0; m < Iterations::kRow; ++m) {

        Index load_byte_offset = layout_({m * WmmaShape::kRow, k * WmmaShape::kColumn}) * sizeof_bits<Element>::value / 8;

        const WmmaDataType *ptr = reinterpret_cast<const WmmaDataType *>(pointer_ + byte_offset_ + load_byte_offset + byte_offset);

        nvcuda::wmma::load_matrix_sync(frag[m], ptr, stride_);
      }
    }
  }

  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {
    load_with_byte_offset(frag, 0);
  }

  CUTLASS_HOST_DEVICE
  void store_with_byte_offset(Fragment const &frag, Index byte_offset) const {

    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < Iterations::kColumn; ++k) {
      CUTLASS_PRAGMA_UNROLL
      for (int m = 0; m < Iterations::kRow; ++m) {

        Index store_byte_offset = layout_({m * WmmaShape::kRow, k * WmmaShape::kColumn}) * sizeof_bits<Element>::value / 8;

        WmmaDataType *ptr = reinterpret_cast<WmmaDataType *>(const_cast<char *>(pointer_) + byte_offset_ + store_byte_offset + byte_offset);

        nvcuda::wmma::store_matrix_sync(ptr, frag[m], stride_);
      }
    }
  }

  CUTLASS_HOST_DEVICE
  void store(Fragment const &frag) const {
    store_with_byte_offset(frag, 0);
  }

  CUTLASS_DEVICE
  void set_kgroup_index(int k_group) {
    // no operation here
  }
};
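// Worked example (hypothetical shapes, not mandated by this header): with a warp tile
// Shape = MatrixShape<64, 16> and a 16x16x16 WMMA operator, WmmaShape = MatrixShape<16, 16>,
// Iterations = MatrixShape<4, 1>, and Fragment holds 4 wmma FragmentA fragments. Each
// operator++ advances the pointer by WmmaShape::kColumn (= 16) columns along K.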
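// Specialization for the B operand: the warp tile is covered by Shape::kColumn / WmmaShape::kColumn
// WMMA FragmentB fragments, and operator++ advances the iterator by one WMMA tile (rows) along the
// K dimension. This specialization assumes 32 participating threads (a full warp).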
template <
  typename Shape_,
  typename Element_,
  typename Layout_,
  int OpDelta_,
  typename Policy_>
class MmaTensorOpWmmaMultiplicandTileIterator<
  Shape_, Operand::kB, Element_, Layout_,
  OpDelta_, 32, Policy_> {
public:

  using Shape = Shape_;

  static Operand const kOperand = Operand::kB;

  using Element = Element_;

  using Layout = Layout_;

  static int const kOpDelta = OpDelta_;

  using Policy = Policy_;

  //
  // Derived quantities
  //

  using TensorRef = TensorRef<Element, Layout>;

  using Index = typename TensorRef::Index;

  using LongIndex = typename TensorRef::LongIndex;

  using TensorCoord = typename TensorRef::TensorCoord;

  using WmmaShape = MatrixShape<
    Policy::Operator::Shape::kK,
    Policy::Operator::Shape::kN
  >;

  using WmmaDataType = typename cutlass::arch::CutlassToWmmaDataType<Element>::Type;

  using Iterations = MatrixShape<
    1,
    Shape::kColumn / WmmaShape::kColumn
  >;

  using Fragment = WmmaFragmentArray<typename Policy::Operator::FragmentB, Iterations::kCount>;

  static_assert(kOperand == Operand::kB,
    "MmaTensorOpWmmaMultiplicandTileIterator may only be instantiated for B operands to warp-level Mma.");

  static_assert(
    platform::is_same<cutlass::layout::RowMajor, Layout>::value ||
    platform::is_same<cutlass::layout::ColumnMajor, Layout>::value,
    "Supported list of memory layouts for WMMA are: RowMajor, ColumnMajor");

  static_assert(kOpDelta == 1,
    "Alternative arrangements not supported at present.");

private:

  char const *pointer_;

  Index byte_offset_;

  Index stride_;

  Layout layout_;

public:

  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator() { }

  CUTLASS_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator(
    TensorRef const &ref,
    int lane_id
  ): pointer_(reinterpret_cast<char const*>(ref.data())), byte_offset_(0), stride_(ref.stride(0)), layout_(ref.stride(0)) {
  }

  CUTLASS_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator &add_pointer_offset(LongIndex offset) {

    byte_offset_ += (offset * sizeof_bits<Element>::value) / 8;

    return *this;
  }

  CUTLASS_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) {

    Index elements_offset = layout_({tile_offset.row() * WmmaShape::kRow, tile_offset.column() * Shape::kColumn});

    byte_offset_ += (elements_offset * sizeof_bits<Element>::value) / 8;

    return *this;
  }

  CUTLASS_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator & operator++() {

    Index elements_offset = layout_({WmmaShape::kRow, 0});

    byte_offset_ += (elements_offset * sizeof_bits<Element>::value) / 8;

    return *this;
  }

  CUTLASS_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator & operator--() {

    Index elements_offset = layout_({WmmaShape::kRow, 0});

    byte_offset_ -= (elements_offset * sizeof_bits<Element>::value) / 8;

    return *this;
  }

  CUTLASS_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) {
    add_tile_offset(tile_offset);
    return *this;
  }

  CUTLASS_DEVICE
  MmaTensorOpWmmaMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) {
    add_tile_offset(-tile_offset);
    return *this;
  }

  CUTLASS_HOST_DEVICE
  void load_with_byte_offset(Fragment &frag, Index byte_offset) const {

    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < Iterations::kRow; ++k) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < Iterations::kColumn; ++n) {

        Index load_byte_offset = layout_({k * WmmaShape::kRow, n * WmmaShape::kColumn}) * sizeof_bits<Element>::value / 8;

        const WmmaDataType *ptr = reinterpret_cast<const WmmaDataType *>(pointer_ + byte_offset_ + load_byte_offset + byte_offset);

        nvcuda::wmma::load_matrix_sync(frag[n], ptr, stride_);
      }
    }
  }

  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {
    load_with_byte_offset(frag, 0);
  }

  CUTLASS_HOST_DEVICE
  void store_with_byte_offset(Fragment const &frag, Index byte_offset) const {

    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < Iterations::kRow; ++k) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < Iterations::kColumn; ++n) {

        Index store_byte_offset = layout_({k * WmmaShape::kRow, n * WmmaShape::kColumn}) * sizeof_bits<Element>::value / 8;

        WmmaDataType *ptr = reinterpret_cast<WmmaDataType *>(const_cast<char *>(pointer_) + byte_offset_ + store_byte_offset + byte_offset);

        nvcuda::wmma::store_matrix_sync(ptr, frag[n], stride_);
      }
    }
  }

  CUTLASS_HOST_DEVICE
  void store(Fragment const &frag) const {
    store_with_byte_offset(frag, 0);
  }

  CUTLASS_DEVICE
  void set_kgroup_index(int k_group) {
    // no operation here
  }
};
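// Worked example (hypothetical shapes, not mandated by this header): with a warp tile
// Shape = MatrixShape<16, 64> and a 16x16x16 WMMA operator, WmmaShape = MatrixShape<16, 16>,
// Iterations = MatrixShape<1, 4>, and Fragment holds 4 wmma FragmentB fragments. Each
// operator++ advances the pointer by WmmaShape::kRow (= 16) rows along K.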
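// Iterator over warp-level accumulator tiles backed by the nvcuda::wmma API. Shape_ is the
// warp-level accumulator tile (concept: MatrixShape), Element_ and Layout_ describe the
// accumulator in memory, OpDelta_ is the interval between accesses, and Policy_ supplies the
// underlying WMMA operator. Loads and stores go through nvcuda::wmma::load_matrix_sync /
// store_matrix_sync using the wmma layout_t value derived from Layout.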
template <
  typename Shape_,
  typename Element_,
  typename Layout_,
  typename OpDelta_,
  typename Policy_>
class MmaTensorOpWmmaAccumulatorTileIterator;

template <
  typename Shape_,
  typename Element_,
  typename Layout_,
  typename OpDelta_,
  typename Policy_>
class MmaTensorOpWmmaAccumulatorTileIterator
{
public:

  using Shape = Shape_;

  using Element = Element_;

  using Layout = Layout_;

  using OpDelta = OpDelta_;

  static int const kThreads = 32;

  using Policy = Policy_;

  //
  // Derived quantities
  //

  using TensorRef = TensorRef<Element, Layout>;

  using Index = typename TensorRef::Index;

  using LongIndex = typename TensorRef::LongIndex;

  using TensorCoord = typename TensorRef::TensorCoord;

  using WmmaShape = MatrixShape<
    Policy::Operator::Shape::kM,
    Policy::Operator::Shape::kN
  >;

  using WmmaDataType = typename cutlass::arch::CutlassToWmmaDataType<Element>::Type;

  static nvcuda::wmma::layout_t const WmmaLayout = cutlass::arch::CutlassToWmmaLayout<Layout>::value;

  using Iterations = MatrixShape<
    Shape::kRow / WmmaShape::kRow,
    Shape::kColumn / WmmaShape::kColumn
  >;

  using Fragment = WmmaFragmentArray<typename Policy::Operator::FragmentC, Iterations::kCount>;

  static_assert(
    platform::is_same<cutlass::layout::RowMajor, Layout>::value ||
    platform::is_same<cutlass::layout::ColumnMajor, Layout>::value,
    "Supported list of memory layouts for WMMA are: RowMajor, ColumnMajor");

private:

  TensorRef ref_;

public:

  CUTLASS_HOST_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator() { }

  CUTLASS_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator(
    TensorRef const &ref,
    int lane_id
  ): ref_(ref) { }

  CUTLASS_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator &add_pointer_offset(LongIndex offset) {
    ref_.add_pointer_offset(offset);
    return *this;
  }

  CUTLASS_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator &add_tile_offset(TensorCoord const &tile_offset) {
    ref_.add_coord_offset({tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn});
    return *this;
  }

  CUTLASS_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator & operator++() {
    ref_.add_coord_offset({Shape::kRow, 0});
    return *this;
  }

  CUTLASS_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator & operator--() {
    ref_.add_coord_offset({-Shape::kRow, 0});
    return *this;
  }

  CUTLASS_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator & operator+=(TensorCoord const &tile_offset) {
    add_tile_offset(tile_offset);
    return *this;
  }

  CUTLASS_DEVICE
  MmaTensorOpWmmaAccumulatorTileIterator & operator-=(TensorCoord const &tile_offset) {
    add_tile_offset(-tile_offset);
    return *this;
  }

  CUTLASS_HOST_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {

    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < Iterations::kRow; ++m) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < Iterations::kColumn; ++n) {

        const WmmaDataType *ptr = reinterpret_cast<const WmmaDataType*>(ref_.data() + ref_.offset({m * WmmaShape::kRow, n * WmmaShape::kColumn}) + pointer_offset);

        nvcuda::wmma::load_matrix_sync(frag[m * Iterations::kColumn + n], ptr, ref_.stride()[0], WmmaLayout);
      }
    }
  }

  CUTLASS_HOST_DEVICE
  void load(Fragment &frag) const {
    load_with_pointer_offset(frag, 0);
  }

  CUTLASS_HOST_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {

    CUTLASS_PRAGMA_UNROLL
    for (int m = 0; m < Iterations::kRow; ++m) {
      CUTLASS_PRAGMA_UNROLL
      for (int n = 0; n < Iterations::kColumn; ++n) {

        WmmaDataType *ptr = reinterpret_cast<WmmaDataType*>(ref_.data() + ref_.offset({m * WmmaShape::kRow, n * WmmaShape::kColumn}) + pointer_offset);

        nvcuda::wmma::store_matrix_sync(ptr, frag[m * Iterations::kColumn + n], ref_.stride()[0], WmmaLayout);
      }
    }
  }

  CUTLASS_HOST_DEVICE
  void store(Fragment const &frag) const {
    store_with_pointer_offset(frag, 0);
  }

  CUTLASS_DEVICE
  void set_kgroup_index(int k_group) {
    // no operation here
  }
};
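// Worked example (hypothetical shapes, not mandated by this header): with an accumulator tile
// Shape = MatrixShape<64, 64> and a 16x16x16 WMMA operator, WmmaShape = MatrixShape<16, 16>
// and Iterations = MatrixShape<4, 4>, so Fragment holds 16 wmma FragmentC fragments indexed
// as m * Iterations::kColumn + n, matching the load/store loops above.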

} // namespace warp
} // namespace gemm
} // namespace cutlass

#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED)
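/*
  Illustrative sketch (not part of this header): the iterators above are thin wrappers over the
  nvcuda::wmma API. The kernel below shows the underlying calls for a single 16x16x16 half/float
  tile; the shapes, types, and kernel name are assumptions chosen for illustration only, not
  something this header defines.

    #include <mma.h>
    using namespace nvcuda;

    __global__ void single_tile_wmma(half const *A, half const *B, float *C,
                                     int lda, int ldb, int ldc) {
      // One fragment per multiplicand plus an accumulator fragment.
      wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
      wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag;
      wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;

      wmma::fill_fragment(c_frag, 0.0f);

      // Warp-collective loads; lda/ldb are leading dimensions in elements,
      // analogous to the stride_ kept by the multiplicand iterators above.
      wmma::load_matrix_sync(a_frag, A, lda);
      wmma::load_matrix_sync(b_frag, B, ldb);

      // D = A * B + C accumulated in the fragment.
      wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);

      // Store the accumulator with an explicit memory layout, as the
      // accumulator iterator above does with WmmaLayout.
      wmma::store_matrix_sync(C, c_frag, ldc, wmma::mem_row_major);
    }
*/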