CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
mma_simt_tile_iterator.h
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
30 #pragma once
31 
32 #include "cutlass/cutlass.h"
33 #include "cutlass/array.h"
34 #include "cutlass/tensor_ref.h"
35 #include "cutlass/matrix_shape.h"
36 #include "cutlass/layout/matrix.h"
37 
38 #include "cutlass/gemm/gemm.h"
40 
42 
43 namespace cutlass {
44 namespace gemm {
45 namespace warp {
46 
48 
53 template <
55  typename Shape_,
57  Operand Operand,
59  typename Element_,
61  typename Layout_,
63  typename Policy_,
65  int PartitionsK = 1,
67  int PartitionGroupSize = 1
68 >
70 class MmaSimtTileIterator;
72 
77 template <
79  typename Shape_,
81  typename Element_,
83  typename Policy_,
85  int PartitionsK,
87  int PartitionGroupSize
88 >
89 class MmaSimtTileIterator<Shape_, Operand::kA, Element_, layout::ColumnMajor, Policy_, PartitionsK, PartitionGroupSize> {
90 public:
91 
93  using Shape = Shape_;
94 
96  static Operand const kOperand = Operand::kA;
97 
99  using Element = Element_;
100 
103 
105  using Policy = Policy_;
106 
109 
111  using Index = typename TensorRef::Index;
112 
114  using LongIndex = typename TensorRef::LongIndex;
115 
118 
119  //
120  // Derived quantities
121  //
122 
123  static_assert(!(Shape::kRow % Policy::WarpShape::kRow),
124  "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension.");
125 
126  static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
127  static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
128  static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");
129  static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");
130 
132  using ThreadShape = MatrixShape<
133  Shape::kRow / Policy::WarpShape::kRow,
134  Shape::kColumn
135  >;
136 
137  static_assert(!(ThreadShape::kRow % Policy::LaneMmaShape::kM),
138  "Thread-level GEMM must be divisible by Policy::LaneMmaShape.");
139 
141  using Iterations = MatrixShape<
142  ThreadShape::kRow / Policy::LaneMmaShape::kM,
143  ThreadShape::kColumn
144  >;
145 
147  using Fragment = Array<Element, ThreadShape::kCount>;
148 
149 private:
150 
153 
154 public:
155 
159 
163  TensorRef ref,
164  int lane_id
165  ) {
166 
167  // compute offset based on thread ID and lane layout
168  typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
169 
170  MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
171  MatrixCoord(Policy::LaneMmaShape::kM, 0);
172 
173  ref.add_coord_offset(lane_offset);
174 
175  ref_.reset(
176  reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> *>(ref.data()),
177  ref.stride(0) / Policy::LaneMmaShape::kM);
178  }
179 
180 
184  ref_.add_pointer_offset(offset);
185  return *this;
186  }
187 
191 
192  ref_.add_coord_offset({
193  coord.row() * Shape::kRow / Policy::LaneMmaShape::kM,
194  coord.column() * Shape::kColumn});
195 
196  return *this;
197  }
198 
202 
203  ref_.add_coord_offset({0, Shape::kColumn});
204 
205  return *this;
206  }
207 
211 
212  ref_.add_coord_offset({0, -Shape::kColumn});
213 
214  return *this;
215  }
216 
219  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {
220  Array<Element, Policy::LaneMmaShape::kM> *dst_ptr =
221  reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> *>(&frag);
222 
224  for (int k = 0; k < Iterations::kColumn; ++k) {
226  for (int m = 0; m < Iterations::kRow; ++m) {
227  dst_ptr[m + k * Iterations::kRow] =
228  *(ref_.data() + ref_.offset({m * Policy::WarpShape::kRow, k}) + pointer_offset / Policy::LaneMmaShape::kM);
229  }
230  }
231  }
234  void load(Fragment &frag) const {
235  load_with_pointer_offset(frag, 0);
236  }
237 
240  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {
241 
242  Array<Element, Policy::LaneMmaShape::kM> const *src_ptr =
243  reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> const *>(&frag);
244 
246  for (int k = 0; k < Iterations::kColumn; ++k) {
248  for (int m = 0; m < Iterations::kRow; ++m) {
249  *(ref_.data() + ref_.offset({m * Policy::WarpShape::kRow, k}) + pointer_offset / Policy::LaneMmaShape::kM) =
250  src_ptr[m + k * Iterations::kRow];
251  }
252  }
253  }
254 
257  void store(Fragment const &frag) const {
258  store_with_pointer_offset(frag, 0);
259  }
260 
268  CUTLASS_DEVICE
269  void set_kgroup_index(int k_group) {
270  // no operation here
271  }
272 };
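// Worked shape example for the specialization above (illustrative values, assumed rather than
// taken from this header): with Shape = MatrixShape<64, 8>, Policy::WarpShape::kRow = 4 and
// Policy::LaneMmaShape::kM = 4, ThreadShape is MatrixShape<64 / 4, 8> = <16, 8>, so each lane
// owns 128 elements of A, and Iterations is MatrixShape<16 / 4, 8> = <4, 8>. load() then issues
// 4 x 8 = 32 reads of Array<Element, 4>, and successive row-blocks read by the same lane are
// Policy::WarpShape::kRow blocks apart in memory.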
273 
275 
280 template <
282  typename Shape_,
284  typename Element_,
286  typename Policy_,
288  int PartitionsK,
290  int PartitionGroupSize
291 >
292 class MmaSimtTileIterator<Shape_, Operand::kB, Element_, layout::RowMajor, Policy_, PartitionsK, PartitionGroupSize> {
293 public:
294 
296  using Shape = Shape_;
297 
299  static Operand const kOperand = Operand::kB;
300 
302  using Element = Element_;
303 
306 
308  using Policy = Policy_;
309 
312 
314  using Index = typename TensorRef::Index;
315 
317  using LongIndex = typename TensorRef::LongIndex;
318 
321 
322  //
323  // Derived quantities
324  //
325 
326  static_assert(!(Shape::kColumn % Policy::WarpShape::kColumn),
327  "The warp-level GEMM N size must be divisible by the number of threads arranged along the N dimension.");
328 
329  static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
330  static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
331  static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero.");
332  static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero.");
333 
335  using ThreadShape = MatrixShape<
336  Shape::kRow,
337  Shape::kColumn / Policy::WarpShape::kColumn
338  >;
339 
340  static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN),
341  "Thread-level GEMM must be divisible by Policy::LaneMmaShape.");
342 
344  using Iterations = MatrixShape<
345  ThreadShape::kRow,
346  ThreadShape::kColumn / Policy::LaneMmaShape::kN
347  >;
348 
350  using Fragment = Array<Element, ThreadShape::kCount>;
351 
352 private:
353 
356 
357 
358 public:
359 
363 
367  TensorRef ref,
368  int lane_id
369  ) {
370 
371  // compute offset based on thread ID and lane layout
372  typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
373 
374  MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
375  MatrixCoord(0, Policy::LaneMmaShape::kN);
376 
377  ref.add_coord_offset(lane_offset);
378 
379  ref_.reset(
380  reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(ref.data()),
381  ref.stride(0) / Policy::LaneMmaShape::kN);
382  }
383 
387  ref_.add_pointer_offset(offset);
388  return *this;
389  }
390 
394 
395  ref_.add_coord_offset({
396  coord.row() * Shape::kRow,
397  coord.column() * Shape::kColumn / Policy::LaneMmaShape::kN});
398 
399  return *this;
400  }
401 
405 
406  ref_.add_coord_offset({Shape::kRow, 0});
407 
408  return *this;
409  }
410 
414 
415  ref_.add_coord_offset({-Shape::kRow, 0});
416 
417  return *this;
418  }
419 
422  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {
423 
424  Array<Element, Policy::LaneMmaShape::kN> *dst_ptr =
425  reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(&frag);
426 
428  for (int k = 0; k < Iterations::kRow; ++k) {
430  for (int n = 0; n < Iterations::kColumn; ++n) {
431  dst_ptr[n + k * Iterations::kColumn] =
432  *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kColumn}) + pointer_offset / Policy::LaneMmaShape::kN);
433  }
434  }
435  }
436 
439  void load(Fragment &frag) const {
440  load_with_pointer_offset(frag, 0);
441  }
442 
445  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {
446 
447  Array<Element, Policy::LaneMmaShape::kN> const *src_ptr =
448  reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> const *>(&frag);
449 
451  for (int k = 0; k < Iterations::kRow; ++k) {
453  for (int n = 0; n < Iterations::kColumn; ++n) {
454  *(ref_.data() + ref_.offset({k, n * Policy::WarpShape::kColumn}) + pointer_offset / Policy::LaneMmaShape::kN) =
455  src_ptr[n + k * Iterations::kColumn];
456  }
457  }
458  }
459 
462  void store(Fragment const &frag, Index pointer_offset) const {
463  store_with_pointer_offset(frag, pointer_offset);
464  }
465 
473  CUTLASS_DEVICE
474  void set_kgroup_index(int k_group) {
475  // no operation here
476  }
477 };
478 
480 
485 template <
487  typename Shape_,
489  typename Element_,
491  typename Policy_
492 >
493 class MmaSimtTileIterator<Shape_, Operand::kC, Element_, layout::ColumnMajor, Policy_> {
494 public:
495 
497  using Shape = Shape_;
498 
500  static Operand const kOperand = Operand::kC;
501 
503  using Element = Element_;
504 
507 
509  using Policy = Policy_;
510 
513 
515  using Index = typename TensorRef::Index;
516 
518  using LongIndex = typename TensorRef::LongIndex;
519 
522 
523  //
524  // Derived quantities
525  //
526 
528  static_assert((!(Shape::kRow % Policy::WarpShape::kRow)) && (!(Shape::kColumn % Policy::WarpShape::kColumn)),
529  "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp.");
530 
531  static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
532  static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
533  static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");
534  static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero.");
535  static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");
536  static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero.");
537 
539  using ThreadShape = MatrixShape<
540  Shape::kRow / Policy::WarpShape::kRow,
541  Shape::kColumn / Policy::WarpShape::kColumn
542  >;
543 
545  static_assert((!(ThreadShape::kRow % Policy::LaneMmaShape::kM)) && (!(ThreadShape::kColumn % Policy::LaneMmaShape::kN)),
546  "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp.");
547 
549  using Iterations = MatrixShape<
550  ThreadShape::kRow / Policy::LaneMmaShape::kM,
551  ThreadShape::kColumn / Policy::LaneMmaShape::kN
552  >;
553 
554  using Delta = MatrixShape<
555  Policy::WarpShape::kRow * Policy::LaneMmaShape::kM,
556  Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN
557  >;
558 
560  using Fragment = Array<Element, ThreadShape::kCount>;
561 
562 private:
563 
564  TensorRef ref_;
565 
566 public:
567 
571 
575  TensorRef const &ref,
576  int lane_id
577  ):
578  ref_(ref) {
579 
580  // compute offset based on thread ID and lane layout
581  typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
582 
583  MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
584  MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN);
585 
586  ref_.add_coord_offset(lane_offset);
587  }
588 
592  ref_.add_pointer_offset(offset);
593  return *this;
594  }
595 
599 
600  ref_.add_coord_offset({
601  coord.row() * Shape::kRow,
602  coord.column() * Shape::kColumn});
603 
604  return *this;
605  }
606 
610 
611  ref_.add_coord_offset({Shape::kRow, 0});
612 
613  return *this;
614  }
615 
619 
620  ref_.add_coord_offset({-Shape::kRow, 0});
621 
622  return *this;
623  }
624 
628  Fragment &frag,
629  Index pointer_offset) const {
630 
632  for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
634  for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {
635 
636  Array<Element, Policy::LaneMmaShape::kM> const *src_ptr =
637  reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> const *>(
638  ref_.data() + pointer_offset + ref_.offset({0, mma_n * Delta::kColumn + n}));
639 
641  for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
642 
643  Array<Element, Policy::LaneMmaShape::kM> *dst_ptr =
644  reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> *>(&frag) +
645  mma_m + Iterations::kRow * (n + mma_n * Policy::LaneMmaShape::kN);
646 
647  *dst_ptr = src_ptr[mma_m * Policy::WarpShape::kRow];
648  }
649  }
650  }
651  }
652 
655  void load(Fragment &frag) const {
656  load_with_pointer_offset(frag, 0);
657  }
658 
661  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {
662 
664  for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
666  for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {
667 
668  Array<Element, Policy::LaneMmaShape::kM> *dst_ptr =
669  reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> *>(
670  ref_.data() + pointer_offset + ref_.offset({0, mma_n * Delta::kColumn + n}));
671 
673  for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
674 
675  Array<Element, Policy::LaneMmaShape::kM> const *src_ptr =
676  reinterpret_cast<Array<Element, Policy::LaneMmaShape::kM> const *>(&frag) +
677  mma_m + Iterations::kRow * (n + mma_n * Policy::LaneMmaShape::kN);
678 
679  dst_ptr[mma_m * Policy::WarpShape::kRow] = *src_ptr;
680  }
681  }
682  }
683  }
686  void store(Fragment const &frag) const {
687  store_with_pointer_offset(frag, 0);
688  }
689 };
690 
692 
697 template <
699  typename Shape_,
701  typename Element_,
703  typename Policy_
704 >
705 class MmaSimtTileIterator<Shape_, Operand::kC, Element_, layout::RowMajor, Policy_> {
706 public:
707 
709  using Shape = Shape_;
710 
712  static Operand const kOperand = Operand::kC;
713 
715  using Element = Element_;
716 
719 
721  using Policy = Policy_;
722 
725 
727  using Index = typename TensorRef::Index;
728 
730  using LongIndex = typename TensorRef::LongIndex;
731 
734 
735  //
736  // Derived quantities
737  //
738 
740  static_assert((!(Shape::kRow % Policy::WarpShape::kRow)) && (!(Shape::kColumn % Policy::WarpShape::kColumn)),
741  "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp.");
742 
743  static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
744  static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
745  static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");
746  static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero.");
747  static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");
748  static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero.");
749 
751  using ThreadShape = MatrixShape<
752  Shape::kRow / Policy::WarpShape::kRow,
753  Shape::kColumn / Policy::WarpShape::kColumn
754  >;
755 
757  static_assert((!(ThreadShape::kRow % Policy::LaneMmaShape::kM)) && (!(ThreadShape::kColumn % Policy::LaneMmaShape::kN)),
758  "Warp-level GEMM shape must be divisible by the arrangement of threads in the warp.");
759 
761  using Iterations = MatrixShape<
762  ThreadShape::kRow / Policy::LaneMmaShape::kM,
763  ThreadShape::kColumn / Policy::LaneMmaShape::kN
764  >;
765 
766  using Delta = MatrixShape<
767  Policy::WarpShape::kRow * Policy::LaneMmaShape::kM,
768  Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN
769  >;
770 
772  using Fragment = Array<Element, ThreadShape::kCount>;
773 
774 private:
775 
776  TensorRef ref_;
777 
778 public:
779 
783 
787  TensorRef const &ref,
788  int lane_id
789  ):
790  ref_(ref) {
791 
792  // compute offset based on thread ID and lane layout
793  typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
794 
795  MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
796  MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN);
797 
798  ref_.add_coord_offset(lane_offset);
799  }
800 
804  ref_.add_pointer_offset(offset);
805  return *this;
806  }
807 
811 
812  ref_.add_coord_offset({
813  coord.row() * Shape::kRow,
814  coord.column() * Shape::kColumn});
815 
816  return *this;
817  }
818 
822 
823  ref_.add_coord_offset({Shape::kRow, 0});
824 
825  return *this;
826  }
827 
831 
832  ref_.add_coord_offset({-Shape::kRow, 0});
833 
834  return *this;
835  }
836 
840  Fragment &frag,
841  Index pointer_offset) const {
842 
844  for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
846  for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {
847 
848  Array<Element, Policy::LaneMmaShape::kN> const *src_ptr =
849  reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> const *>(
850  ref_.data() + pointer_offset + ref_.offset({mma_m * Delta::kRow + m, 0}));
851 
853  for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
854 
855  Array<Element, Policy::LaneMmaShape::kN> *dst_ptr =
856  reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(&frag) +
857  mma_n + Iterations::kColumn * (m + mma_m * Policy::LaneMmaShape::kM);
858 
859  *dst_ptr = src_ptr[mma_n * Policy::WarpShape::kColumn];
860  }
861  }
862  }
863  }
864 
867  void load(Fragment &frag) const {
868  load_with_pointer_offset(frag, 0);
869  }
870 
873  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {
874 
876  for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
878  for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {
879 
880  Array<Element, Policy::LaneMmaShape::kN> *dst_ptr =
881  reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> *>(
882  ref_.data() + pointer_offset + ref_.offset({mma_m * Delta::kRow + m, 0}));
883 
885  for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
886 
887  Array<Element, Policy::LaneMmaShape::kN> const *src_ptr =
888  reinterpret_cast<Array<Element, Policy::LaneMmaShape::kN> const *>(&frag) +
889  mma_n + Iterations::kColumn * (m + mma_m * Policy::LaneMmaShape::kM);
890 
891  dst_ptr[mma_n * Policy::WarpShape::kColumn] = *src_ptr;
892  }
893  }
894  }
895  }
896 
899  void store(Fragment const &frag) const {
900  store_with_pointer_offset(frag, 0);
901  }
902 };
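// Illustrative sketch (not part of the CUTLASS API): the Operand::kC iterators above move a
// warp's accumulator tile between memory and registers. The helper below is an assumption made
// only for this example; IteratorC stands for one of the Operand::kC specializations above, and
// the update step is a placeholder. It shows the construct / load / modify / store pattern.
template <typename IteratorC>
CUTLASS_DEVICE
void example_update_accumulators(typename IteratorC::TensorRef ref_C, int lane_id) {
  IteratorC iter_C(ref_C, lane_id);            // lane offset covers LaneMmaShape::kM x kN blocks
  typename IteratorC::Fragment accum;

  iter_C.load(accum);                          // read this lane's portion of the C tile
  // ... accumulate warp-level MMA results into 'accum' ...
  iter_C.store(accum);                         // write the updated fragment back
}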
903 
905 
907 
912 template <
914  typename Shape_,
916  typename Element_,
918  typename Policy_,
920  int PartitionsK,
922  int PartitionGroupSize
923 >
924 class MmaSimtTileIterator<Shape_, Operand::kA, Element_, layout::ColumnMajorInterleaved<4>, Policy_, PartitionsK, PartitionGroupSize> {
925 public:
926 
928  using Shape = Shape_;
929 
931  static Operand const kOperand = Operand::kA;
932 
934  using Element = Element_;
935 
938 
940  using Policy = Policy_;
941 
944 
946  using Index = typename TensorRef::Index;
947 
949  using LongIndex = typename TensorRef::LongIndex;
950 
953 
955  static const int kInterleave = 4;
956 
958  static const int kPartitionsK = PartitionsK;
959 
961  static const int kGroupPerTile = PartitionGroupSize / Shape::kColumn;
962 
963  //
964  // Derived quantities
965  //
966 
967  static_assert(!(Shape::kRow % Policy::WarpShape::kRow),
968  "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension.");
969 
970  static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
971  static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
972  static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero.");
973  static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero.");
974 
976  using ThreadShape = MatrixShape<
977  Shape::kRow / Policy::WarpShape::kRow,
978  Shape::kColumn
979  >;
980 
981  static_assert(!(ThreadShape::kRow % Policy::LaneMmaShape::kM) && !(ThreadShape::kColumn % Policy::LaneMmaShape::kK),
982  "Thread-level GEMM must be divisible by Policy::LaneMmaShape.");
983 
985  using Iterations = MatrixShape<
986  ThreadShape::kRow / Policy::LaneMmaShape::kM,
987  ThreadShape::kColumn / Policy::LaneMmaShape::kK
988  >;
989 
991  using Fragment = Array<Element, ThreadShape::kCount>;
992 
993 private:
994 
997 
999  int k_group_idx_;
1000 
1001 public:
1004 
1008  TensorRef ref,
1009  int lane_id
1010  ) {
1011 
1012  // compute offset based on thread ID and lane layout
1013  typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
1014 
1015  MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
1016  MatrixCoord(Policy::LaneMmaShape::kM, 0);
1017 
1018  ref.add_coord_offset(lane_offset);
1019 
1020  k_group_idx_ = 0;
1021  ref_.reset(reinterpret_cast<Array<Element, Policy::LaneMmaShape::kMK> *>(ref.data()), ref.stride(0)/Policy::LaneMmaShape::kMK);
1022  }
1023 
1024 
1028  ref_.add_pointer_offset(offset);
1029  return *this;
1030  }
1031 
1035 
1036  ref_.add_coord_offset({
1037  coord.row() * Shape::kRow / Policy::LaneMmaShape::kMK,
1038  coord.column() * Shape::kColumn});
1039 
1040  return *this;
1041  }
1042 
1046 
1047  add_tile_offset({0, 1});
1048 
1049  if (kPartitionsK > 1) {
1050  ++k_group_idx_;
1051  // Jump to next stage
1052  if (k_group_idx_ == kGroupPerTile) {
1053  k_group_idx_ = 0;
1054  add_tile_offset({0, kGroupPerTile * (kPartitionsK-1)});
1055  }
1056  }
1057 
1058  return *this;
1059  }
1060 
1064 
1065  ref_.add_coord_offset({0, -Shape::kColumn});
1066 
1067  return *this;
1068  }
1069 
1072  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {
1073 
1074  Array<Element, Policy::LaneMmaShape::kMK > *dst_ptr =
1075  reinterpret_cast<Array<Element, Policy::LaneMmaShape::kMK> *>(&frag);
1076 
1078  for (int k = 0; k < Iterations::kColumn; ++k) {
1079 
1081  for (int m = 0; m < Iterations::kRow; ++m) {
1082 
1083  dst_ptr[m + k * Iterations::kRow] =
1084  *((ref_.data() + ref_.offset({m * Policy::WarpShape::kRow / kInterleave,
1085  k*Policy::LaneMmaShape::kK}) + pointer_offset / Policy::LaneMmaShape::kM));
1086  }
1087  }
1088  }
1089 
1092  void load(Fragment &frag) const {
1093  load_with_pointer_offset(frag, 0);
1094  }
1095 
1098  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {
1099 
1100  Array<Element, Policy::LaneMmaShape::kMK> const *src_ptr =
1101  reinterpret_cast<Array<Element, Policy::LaneMmaShape::kMK> const *>(&frag);
1102 
1104  for (int k = 0; k < Iterations::kColumn; ++k) {
1106  for (int m = 0; m < Iterations::kRow; ++m) {
1107  *(ref_.data() + ref_.offset({m * Policy::WarpShape::kRow / kInterleave, k * Policy::LaneMmaShape::kK}) + pointer_offset / Policy::LaneMmaShape::kM) =
1108  src_ptr[m + k * Iterations::kRow];
1109  }
1110  }
1111  }
1112 
1115  void store(Fragment const &frag) const {
1116  store_with_pointer_offset(frag, 0);
1117  }
1118 
1126  CUTLASS_DEVICE
1127  void set_kgroup_index(int k_group) {
1128  // no operation here
1129  }
1130 };
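// Worked example of the sliced-K advance in operator++ above (values assumed for illustration):
// with Shape::kColumn == 4, PartitionGroupSize == 16 and kPartitionsK == 2, each partition owns
// kGroupPerTile = 16 / 4 = 4 consecutive column tiles. operator++ walks tiles 0..3 of the current
// group; when k_group_idx_ wraps to kGroupPerTile, the iterator calls
// add_tile_offset({0, 4 * (2 - 1)}) to skip the four tiles owned by the other partition, landing
// on this partition's next group.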
1131 
1133 
1138 template <
1140  typename Shape_,
1142  typename Element_,
1144  typename Policy_,
1146  int PartitionsK,
1148  int PartitionGroupSize
1149 >
1150 class MmaSimtTileIterator<Shape_, Operand::kB, Element_, layout::RowMajorInterleaved<4>, Policy_, PartitionsK, PartitionGroupSize> {
1151 public:
1152 
1154  using Shape = Shape_;
1155 
1157  static Operand const kOperand = Operand::kB;
1158 
1160  using Element = Element_;
1161 
1164 
1166  using Policy = Policy_;
1167 
1170 
1172  using Index = typename TensorRef::Index;
1173 
1176 
1179 
1181  static const int kInterleave = 4;
1182 
1184  static const int kPartitionsK = PartitionsK;
1185 
1187  static const int kGroupPerTile = PartitionGroupSize / Shape::kRow;
1188 
1189  //
1190  // Derived quantities
1191  //
1192 
1193  static_assert(!(Shape::kColumn % Policy::WarpShape::kColumn),
1194  "The warp-level GEMM N size must be divisible by the number of threads arranged along the N dimension.");
1195 
1196  static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero.");
1197  static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero.");
1198  static_assert(Policy::WarpShape::kColumn > 0, "Policy::WarpShape::kColumn must be greater than zero.");
1199  static_assert(Shape::kColumn / Policy::WarpShape::kColumn > 0, "Shape::kColumn / Policy::WarpShape::kColumn must be greater than zero.");
1200 
1202  using ThreadShape = MatrixShape<
1203  Shape::kRow,
1204  Shape::kColumn / Policy::WarpShape::kColumn
1205  >;
1206 
1207  static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN) && !(ThreadShape::kRow % Policy::LaneMmaShape::kK),
1208  "Thread-level GEMM must be divisible by Policy::LaneMmaShape.");
1209 
1211  using Iterations = MatrixShape<
1212  ThreadShape::kRow / Policy::LaneMmaShape::kK,
1213  ThreadShape::kColumn / Policy::LaneMmaShape::kN
1214  >;
1215 
1217  using Fragment = Array<Element, ThreadShape::kCount>;
1218 
1219 
1220 private:
1221 
1224 
1226  int k_group_idx_;
1227 
1228 public:
1229 
1233 
1237  TensorRef ref,
1238  int lane_id
1239  ) {
1240 
1241  // compute offset based on thread ID and lane layout
1242  typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
1243 
1244  MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
1245  MatrixCoord(0, Policy::LaneMmaShape::kN);
1246 
1247  ref.add_coord_offset(lane_offset);
1248 
1249  k_group_idx_ = 0;
1250 
1251  ref_.reset(
1252  reinterpret_cast<Array<Element, Policy::LaneMmaShape::kKN> *>(ref.data()),
1253  ref.stride(0) / Policy::LaneMmaShape::kKN);
1254  }
1255 
1259  ref_.add_pointer_offset(offset);
1260  return *this;
1261  }
1262 
1266 
1267  ref_.add_coord_offset({
1268  coord.row() * Shape::kRow,
1269  coord.column() * Shape::kColumn / Policy::LaneMmaShape::kKN});
1270 
1271  return *this;
1272  }
1273 
1277 
1278  add_tile_offset({1, 0});
1279 
1280  if (kPartitionsK > 1) {
1281  ++k_group_idx_;
1282  // Jump to next stage
1283  if (k_group_idx_ == kGroupPerTile) {
1284  k_group_idx_ = 0;
1285  add_tile_offset({kGroupPerTile * (kPartitionsK-1), 0});
1286  }
1287  }
1288 
1289  return *this;
1290  }
1291 
1295 
1296  ref_.add_coord_offset({-Shape::kRow, 0});
1297 
1298  return *this;
1299  }
1300 
1303  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const {
1304 
1305  Array<Element, Policy::LaneMmaShape::kKN> *dst_ptr =
1306  reinterpret_cast<Array<Element, Policy::LaneMmaShape::kKN> *>(&frag);
1307 
1309  for (int k = 0; k < Iterations::kRow; ++k) {
1311  for (int n = 0; n < Iterations::kColumn; ++n) {
1312  dst_ptr[n + k * Iterations::kColumn] =
1313  *(ref_.data() + ref_.offset({k * Policy::LaneMmaShape::kK,
1314  n * Policy::WarpShape::kColumn / kInterleave}) + pointer_offset / Policy::LaneMmaShape::kN);
1315  }
1316  }
1317  }
1318 
1321  void load(Fragment &frag) const {
1322  load_with_pointer_offset(frag, 0);
1323  }
1324 
1327  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const {
1328 
1329  Array<Element, Policy::LaneMmaShape::kKN> const *src_ptr =
1330  reinterpret_cast<Array<Element, Policy::LaneMmaShape::kKN> const *>(&frag);
1331 
1333  for (int k = 0; k < Iterations::kRow; ++k) {
1335  for (int n = 0; n < Iterations::kColumn; ++n) {
1336  *(ref_.data() + ref_.offset({k * Policy::LaneMmaShape::kK, n * Policy::WarpShape::kColumn / kInterleave}) + pointer_offset / Policy::LaneMmaShape::kN) =
1337  src_ptr[n + k * Iterations::kColumn];
1338  }
1339  }
1340  }
1341 
1344  void store(Fragment const &frag, Index pointer_offset) const {
1345  store_with_pointer_offset(frag, pointer_offset);
1346  }
1347 
1355  CUTLASS_DEVICE
1356  void set_kgroup_index(int k_group) {
1357  // no operation here
1358  }
1359 };
1360 
1362 
1363 } // namespace warp
1364 } // namespace gemm
1365 } // namespace cutlass
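The sketch below is not part of this header; it illustrates how the Operand::kA and Operand::kB iterators above are typically composed in a warp-level SIMT mainloop. The iterator types, the thread-level MMA functor, and the buffer references are assumptions chosen only for this example.

#include "cutlass/cutlass.h"
#include "cutlass/gemm/warp/mma_simt_tile_iterator.h"

// Hypothetical warp-level mainloop. IteratorA / IteratorB are assumed to be specializations of
// cutlass::gemm::warp::MmaSimtTileIterator for Operand::kA and Operand::kB sharing one policy;
// ThreadMma is an assumed thread-level multiply-accumulate functor.
template <typename IteratorA, typename IteratorB, typename ThreadMma>
CUTLASS_DEVICE
void warp_mainloop(
    typename IteratorA::TensorRef ref_A,     // warp tile of A (e.g. staged in shared memory)
    typename IteratorB::TensorRef ref_B,     // warp tile of B
    typename ThreadMma::FragmentC &accum,    // per-thread accumulators
    int lane_id,
    int gemm_k_iterations) {

  IteratorA iter_A(ref_A, lane_id);          // each lane starts at its own offset
  IteratorB iter_B(ref_B, lane_id);

  typename IteratorA::Fragment frag_A;
  typename IteratorB::Fragment frag_B;

  ThreadMma mma;

  CUTLASS_PRAGMA_UNROLL
  for (int k = 0; k < gemm_k_iterations; ++k) {
    iter_A.load(frag_A);                     // gather this lane's ThreadShape elements
    iter_B.load(frag_B);

    ++iter_A;                                // advance one warp tile along the K dimension
    ++iter_B;

    mma(accum, frag_A, frag_B, accum);       // thread-level multiply-accumulate (assumed interface)
  }
}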