CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
transform/threadblock/predicated_tile_iterator.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
35 #pragma once
36 
37 #include "cutlass/arch/memory.h"
39 
41 
42 namespace cutlass {
43 namespace transform {
44 namespace threadblock {
45 
47 
84 // template <typename Iterator>
85 // __global__ void kernel(
86 // typename Iterator::Params params,
87 // typename Iterator::Element *ptr,
88 // TensorCoord extent) {
89 //
90 // typename Iterator::Fragment fragment;
91 //
92 // TensorCoord threadblock_offset(0, 0);
93 //
94 // Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offset);
95 //
96 //
97 // fragment = *iter; // load "residue" tile first
98 // ++iter; // advance to first "steady state" tile and update internal masks
99 //
100 //
101 // #pragma unroll
102 // for (int i = Remaining - 1; i >= 0; --i) {
103 //
104 // f(fragment);
105 //
106 // if (!i) {
107 // iter.clear_mask(); // light-weight operation to clear masks - subsequent loads become NO-OPs.
108 // }
109 //
110 // fragment = *iter; // load tile during "steady state" phase
111 // ++iter; // advance to next tile - lightweight due to steady-state masks
112 // }
113 // }
114 //
115 // void host(TensorView<Element, 2, layout::PitchLinear> view) {
116 //
117 // using Iterator = transform::threadblock::PredicatedTileIterator;
118 //
119 // typename Iterator::Params params(view.layout());
120 //
121 // kernel<Iterator>(params, view.data());
122 // }
125 template <
126  typename Shape,
127  typename Element,
128  typename Layout,
129  int AdvanceRank,
130  typename ThreadMap,
131  int AccessSize = ThreadMap::kElementsPerAccess
132 >
134 
136 
144 template <typename Shape_, typename Element_, int AdvanceRank,
145  typename ThreadMap_, int AccessSize>
146 class PredicatedTileIterator<Shape_, Element_, layout::PitchLinear, AdvanceRank,
147  ThreadMap_, AccessSize> {
148  public:
150  AdvanceRank == 0 || AdvanceRank == 1,
151  "Specialization for pitch-linear iterator may along advance along the "
152  "contiguous(rank=0) or strided(rank=1) dimension.");
153 
154  using Shape = Shape_;
155  using Element = Element_;
156  using Layout = layout::PitchLinear;
157  static int const kAdvanceRank = AdvanceRank;
158  using ThreadMap = ThreadMap_;
159 
160  using Index = typename Layout::Index;
161  using LongIndex = typename Layout::LongIndex;
162 
166 
167  using Pointer = Element *;
169 
172 
174  using TileAccessIterator =
175  PredicatedTileAccessIterator<Shape, Element, Layout, kAdvanceRank,
176  ThreadMap, AccessType>;
177 
178  static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;
179 
181  using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
182  ThreadMap::kElementsPerAccess>;
183 
185  using Mask = typename TileAccessIterator::Mask;
186 
188  class Params {
189  public:
191 
192  private:
194  typename TileAccessIterator::Params params_;
195 
196  public:
199  Params(Layout const &layout) : params_(layout) { }
200 
202  Params() { }
203  };
204 
205  private:
207  using BytePointer = char *;
208 
209  private:
210  //
211  // Data members
212  //
213 
215  TileAccessIterator address_iterator_;
216 
217  public:
223  Params const &params,
225  Pointer pointer,
227  TensorCoord extent,
229  int thread_id,
231  TensorCoord const &threadblock_offset)
232  : address_iterator_(params.params_, pointer, extent, thread_id,
233  threadblock_offset) {}
234 
238  Params const &params,
239  Pointer pointer,
240  TensorCoord extent,
241  int thread_id
242  )
243  : PredicatedTileIterator(params, pointer, extent, thread_id,
244  make_Coord(0, 0)) {}
245 
248  void add_pointer_offset(LongIndex pointer_offset) {
249  address_iterator_.add_pointer_offset(pointer_offset);
250  }
251 
260  if (kAdvanceRank)
261  address_iterator_.add_tile_offset({0, 1});
262  else
263  address_iterator_.add_tile_offset({1, 0});
264 
265  return *this;
266  }
267 
276  PredicatedTileIterator self(*this);
277  operator++();
278  return self;
279  }
280 
283  void clear_mask() { address_iterator_.clear_mask(); }
284 
287  void enable_mask() { address_iterator_.enable_mask(); }
288 
291  void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); }
292 
295  void get_mask(Mask &mask) { address_iterator_.get_mask(mask); }
296 
/// Loads a fragment from memory, adding `pointer_offset` to every access pointer.
///
/// \param frag            Destination fragment. It is viewed as an array of AccessType; only the
///                        slots whose access is reported valid by the address iterator are written,
///                        all other slots keep whatever value they already held.
/// \param pointer_offset  Offset added to each pointer returned by address_iterator_.get().
///
// NOTE(review): the offset is added directly to the pointer returned by get(). If that pointer is
// an AccessType pointer (its declaration is not visible here), the offset is scaled in units of
// AccessType rather than Element - confirm against the access iterator.
// NOTE(review): pointer_offset is an Index here, while add_pointer_offset() takes a LongIndex;
// confirm that the narrower type is intended.
297  CUTLASS_DEVICE
298  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
299 
// View the fragment as a flat array of vector accesses.
300  AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
301 
// Visit every access of the tile: strided iterations outermost, then contiguous iterations,
// then the accesses that make up one vector.
303  for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
305  for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
306 
308  for (int v = 0; v < kAccessesPerVector; ++v) {
309 
// Linear index of this access; used both as the fragment slot and as the iteration index.
310  int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
311 
// Position the address iterator explicitly on this access before reading its pointer.
312  address_iterator_.set_iteration_index(idx);
313  auto ptr = (address_iterator_.get() + pointer_offset);
314 
// Guarded load: the pointer is dereferenced only when the access is reported valid.
315  if (address_iterator_.valid()) {
316  frag_ptr[idx] = *ptr;
317  }
// Advance the address iterator to the next access.
318  ++address_iterator_;
319  }
320  }
321  }
322  }
323 
325  CUTLASS_DEVICE
326  void load(Fragment &frag) { load_with_pointer_offset(frag, 0); }
327 
/// Stores a fragment to memory, adding `pointer_offset` to every access pointer.
///
/// \param frag            Source fragment, viewed as an array of AccessType.
/// \param pointer_offset  Offset added to each pointer returned by address_iterator_.get().
///
// Unlike load_with_pointer_offset(), the iteration index is set once (to 0) before the loops and
// the address iterator is then advanced only with operator++; `idx` is used solely to select the
// fragment slot.
// NOTE(review): this presumes operator++ on the address iterator walks accesses in the same
// v / c / s order that `idx` is computed in - confirm against the access iterator.
// NOTE(review): the offset is added directly to the pointer returned by get(); if that is an
// AccessType pointer the offset is in units of AccessType, not Element - confirm.
329  CUTLASS_DEVICE
330  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
// Restart the address iterator at the first access of the tile.
331  address_iterator_.set_iteration_index(0);
332  AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
333 
335  for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
337  for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
339  for (int v = 0; v < kAccessesPerVector; ++v) {
340 
// Fragment slot holding the data for this access.
341  int idx = v + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous);
342 
// Guarded store: memory is written only when the access is reported valid.
343  if (address_iterator_.valid()) {
344  *(address_iterator_.get() + pointer_offset) = frag_ptr[idx];
345  }
// Advance the address iterator to the next access.
346  ++address_iterator_;
347  }
348  }
349  }
350  }
351 
353  CUTLASS_DEVICE
354  void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }
355 };
356 
358 
366 template <
367  typename Shape_,
368  typename Element_,
369  int AdvanceRank,
370  typename ThreadMap_,
371  int AccessSize
372 >
373 class PredicatedTileIterator<Shape_, Element_, layout::ColumnMajor, AdvanceRank, ThreadMap_, AccessSize> {
374 public:
375 
376  static_assert(AdvanceRank == 0 || AdvanceRank == 1,
377  "Specialization for pitch-linear iterator may along advance along the "
378  "contiguous(rank=0) or strided(rank=1) dimension.");
379 
380  using Shape = Shape_;
381  using Element = Element_;
382  using Layout = layout::ColumnMajor;
383  static int const kAdvanceRank = AdvanceRank;
384  using ThreadMap = ThreadMap_;
385 
386  using Index = typename Layout::Index;
387  using LongIndex = typename Layout::LongIndex;
388 
392 
393  using Pointer = Element *;
395 
398  Element,
400  (kAdvanceRank == 0 ? 0 : 1),
401  ThreadMap,
402  AccessSize
403  >;
404 
405  using AccessType = typename UnderlyingIterator::AccessType;
406 
408  using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
409 
411  using Mask = typename UnderlyingIterator::Mask;
412 
414  class Params {
415  private:
416 
417  friend PredicatedTileIterator;
418 
420  typename UnderlyingIterator::Params params_;
421 
422  public:
423 
425  Params() { }
426 
429  Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {
430 
431  }
432  };
433 
434 
435 private:
436 
437  //
438  // Data members
439  //
440 
442  UnderlyingIterator iterator_;
443 
444 public:
445 
449  Params const &params,
450  Pointer pointer,
451  TensorCoord extent,
452  int thread_id,
453  TensorCoord const &threadblock_offset
454  ):
455  iterator_(
456  params.params_,
457  pointer,
458  layout::PitchLinearCoord(extent.row(), extent.column()),
459  thread_id,
460  layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column())
461  ) { }
462 
466  Params const &params,
467  Pointer pointer,
468  TensorCoord extent,
469  int thread_id
470  ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
471 
474  void add_pointer_offset(LongIndex pointer_offset) {
475  iterator_.add_pointer_offset(pointer_offset);
476  }
477 
485  ++iterator_;
486  return *this;
487  }
488 
496  PredicatedTileIterator self(*this);
497  operator++();
498  return self;
499  }
500 
503  void clear_mask() {
504  iterator_.clear_mask();
505  }
506 
509  void enable_mask() {
510  iterator_.enable_mask();
511  }
512 
515  void set_mask(Mask const &mask) {
516  iterator_.set_mask(mask);
517  }
518 
521  void get_mask(Mask &mask) {
522  iterator_.get_mask(mask);
523  }
524 
526  CUTLASS_DEVICE
527  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
528  iterator_.load_with_pointer_offset(frag, pointer_offset);
529  }
530 
532  CUTLASS_DEVICE
533  void load(Fragment &frag) {
534  load_with_pointer_offset(frag, 0);
535  }
536 
538  CUTLASS_DEVICE
539  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
540  iterator_.store_with_pointer_offset(frag, pointer_offset);
541  }
542 
544  CUTLASS_DEVICE
545  void store(Fragment const &frag) {
546  store_with_pointer_offset(frag, 0);
547  }
548 };
549 
551 
559 template <
560  typename Shape_,
561  typename Element_,
562  int AdvanceRank,
563  typename ThreadMap_,
564  int AccessSize
565 >
566 class PredicatedTileIterator<Shape_, Element_, layout::RowMajor, AdvanceRank, ThreadMap_, AccessSize> {
567 public:
568 
569  static_assert(AdvanceRank == 0 || AdvanceRank == 1,
570  "Specialization for pitch-linear iterator may along advance along the "
571  "contiguous(rank=0) or strided(rank=1) dimension.");
572 
573  using Shape = Shape_;
574  using Element = Element_;
575  using Layout = layout::RowMajor;
576  static int const kAdvanceRank = AdvanceRank;
577  using ThreadMap = ThreadMap_;
578 
579  using Index = typename Layout::Index;
580  using LongIndex = typename Layout::LongIndex;
581 
585 
586  using Pointer = Element *;
588 
591  Element,
593  (kAdvanceRank == 0 ? 1 : 0),
594  ThreadMap,
595  AccessSize
596  >;
597 
598  using AccessType = typename UnderlyingIterator::AccessType;
599 
601  using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>;
602 
604  using Mask = typename UnderlyingIterator::Mask;
605 
607  class Params {
608  private:
609 
610  friend PredicatedTileIterator;
611 
613  typename UnderlyingIterator::Params params_;
614 
615  public:
616 
618  Params() { }
619 
622  Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) {
623 
624  };
625  };
626 
627 
628 private:
629 
630  //
631  // Data members
632  //
633 
635  UnderlyingIterator iterator_;
636 
637 public:
638 
642  Params const &params,
643  Pointer pointer,
644  TensorCoord extent,
645  int thread_id,
646  TensorCoord const &threadblock_offset
647  ):
648  iterator_(
649  params.params_,
650  pointer,
651  layout::PitchLinearCoord(extent.column(), extent.row()),
652  thread_id,
653  layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row())
654  ) { }
655 
659  Params const &params,
660  Pointer pointer,
661  TensorCoord extent,
662  int thread_id
663  ): PredicatedTileIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) { }
664 
667  void add_pointer_offset(LongIndex pointer_offset) {
668  iterator_.add_pointer_offset(pointer_offset);
669  }
670 
678  ++iterator_;
679  return *this;
680  }
681 
689  PredicatedTileIterator self(*this);
690  operator++();
691  return self;
692  }
693 
696  void clear_mask() {
697  iterator_.clear_mask();
698  }
699 
702  void enable_mask() {
703  iterator_.enable_mask();
704  }
705 
708  void set_mask(Mask const &mask) {
709  iterator_.set_mask(mask);
710  }
711 
714  void get_mask(Mask &mask) {
715  iterator_.get_mask(mask);
716  }
717 
719  CUTLASS_DEVICE
720  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
721  iterator_.load_with_pointer_offset(frag, pointer_offset);
722  }
723 
725  CUTLASS_DEVICE
726  void load(Fragment &frag) {
727  load_with_pointer_offset(frag, 0);
728  }
729 
731  CUTLASS_DEVICE
732  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
733  iterator_.store_with_pointer_offset(frag, pointer_offset);
734  }
735 
737  CUTLASS_DEVICE
738  void store(Fragment const &frag) {
739  store_with_pointer_offset(frag, 0);
740  }
741 };
742 
744 
753 
754 template <typename Shape_, typename Element_, int AdvanceRank,
755  typename ThreadMap_, int AccessSize, int InterleavedK>
756 class PredicatedTileIterator<Shape_, Element_,
757  layout::ColumnMajorInterleaved<InterleavedK>,
758  AdvanceRank, ThreadMap_, AccessSize> {
759  public:
761  AdvanceRank == 0 || AdvanceRank == 1,
762  "Specialization for pitch-linear iterator may along advance along the "
763  "contiguous(rank=0) or strided(rank=1) dimension.");
764 
765  using Shape = Shape_;
766  using Element = Element_;
767  static int const kInterleavedK = InterleavedK;
769  static int const kAdvanceRank = AdvanceRank;
770  using ThreadMap = ThreadMap_;
771 
772  using Index = typename Layout::Index;
773  using LongIndex = typename Layout::LongIndex;
774 
778 
779  using Pointer = Element *;
781 
783  layout::PitchLinearShape<Shape::kRow * kInterleavedK,
784  Shape::kColumn / kInterleavedK>,
785  Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessSize>;
786 
787 
788  using AccessType = typename UnderlyingIterator::AccessType;
789 
791  using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
792  ThreadMap::kElementsPerAccess>;
793 
795  using Mask = typename UnderlyingIterator::Mask;
796 
798  class Params {
799  private:
800  friend PredicatedTileIterator;
801 
803  typename UnderlyingIterator::Params params_;
804 
805  public:
807  Params() {}
808 
811  Params(Layout const &layout)
812  : params_(layout::PitchLinear(layout.stride(0))) {}
813  };
814 
815  private:
816  //
817  // Data members
818  //
819 
821  UnderlyingIterator iterator_;
822 
823  public:
829  Params const &params,
831  Pointer pointer,
833  TensorCoord extent,
835  int thread_id,
837  TensorCoord const &threadblock_offset)
838  : iterator_(params.params_, pointer,
839  layout::PitchLinearCoord(extent.row() * kInterleavedK,
840  extent.column() / kInterleavedK),
841  thread_id,
842  layout::PitchLinearCoord(
843  threadblock_offset.row() * kInterleavedK,
844  threadblock_offset.column() / kInterleavedK)) {}
845 
849  Params const &params,
850  Pointer pointer,
851  TensorCoord extent,
852  int thread_id
853  )
854  : PredicatedTileIterator(params, pointer, extent, thread_id,
855  make_Coord(0, 0)) {}
856 
859  void add_pointer_offset(LongIndex pointer_offset) {
860  iterator_.add_pointer_offset(pointer_offset);
861  }
862 
871  ++iterator_;
872  return *this;
873  }
874 
883  PredicatedTileIterator self(*this);
884  operator++();
885  return self;
886  }
887 
890  void clear_mask() { iterator_.clear_mask(); }
891 
894  void enable_mask() { iterator_.enable_mask(); }
895 
898  void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
899 
902  void get_mask(Mask &mask) { iterator_.get_mask(mask); }
903 
905  CUTLASS_DEVICE
906  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
907  iterator_.load_with_pointer_offset(frag, pointer_offset);
908  }
909 
911  CUTLASS_DEVICE
912  void load(Fragment &frag) { load_with_pointer_offset(frag, 0); }
913 
915  CUTLASS_DEVICE
916  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
917  iterator_.store_with_pointer_offset(frag, pointer_offset);
918  }
919 
921  CUTLASS_DEVICE
922  void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }
923 };
924 
926 
935 template <typename Shape_, typename Element_, int AdvanceRank,
936  typename ThreadMap_, int AccessSize, int InterleavedK>
937 class PredicatedTileIterator<Shape_, Element_,
938  layout::RowMajorInterleaved<InterleavedK>,
939  AdvanceRank, ThreadMap_, AccessSize> {
940  public:
942  AdvanceRank == 0 || AdvanceRank == 1,
943  "Specialization for pitch-linear iterator may along advance along the "
944  "contiguous(rank=0) or strided(rank=1) dimension.");
945 
946  using Shape = Shape_;
947  using Element = Element_;
948  static int const kInterleavedK = InterleavedK;
950  static int const kAdvanceRank = AdvanceRank;
951  using ThreadMap = ThreadMap_;
952 
953  using Index = typename Layout::Index;
954  using LongIndex = typename Layout::LongIndex;
955 
959 
960  using Pointer = Element *;
962 
964  layout::PitchLinearShape<Shape::kColumn * kInterleavedK,
965  Shape::kRow / kInterleavedK>,
966  Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessSize>;
967 
968 
969  using AccessType = typename UnderlyingIterator::AccessType;
970 
972  using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
973  ThreadMap::kElementsPerAccess>;
974 
976  using Mask = typename UnderlyingIterator::Mask;
977 
979  class Params {
980  private:
981  friend PredicatedTileIterator;
982 
984  typename UnderlyingIterator::Params params_;
985 
986  public:
988  Params() {}
989 
992  Params(Layout const &layout)
993  : params_(layout::PitchLinear(layout.stride(0))) {}
994  };
995 
996  private:
997  //
998  // Data members
999  //
1000 
1002  UnderlyingIterator iterator_;
1003 
1004  public:
1010  Params const &params,
1012  Pointer pointer,
1014  TensorCoord extent,
1016  int thread_id,
1018  TensorCoord const &threadblock_offset)
1019  : iterator_(params.params_, pointer,
1020  layout::PitchLinearCoord(extent.column() * kInterleavedK,
1021  extent.row() / kInterleavedK),
1022  thread_id,
1023  layout::PitchLinearCoord(
1024  threadblock_offset.column() * kInterleavedK,
1025  threadblock_offset.row() / kInterleavedK)) {}
1026 
1030  Params const &params,
1031  Pointer pointer,
1032  TensorCoord extent,
1033  int thread_id
1034  )
1035  : PredicatedTileIterator(params, pointer, extent, thread_id,
1036  make_Coord(0, 0)) {}
1037 
1040  void add_pointer_offset(LongIndex pointer_offset) {
1041  iterator_.add_pointer_offset(pointer_offset);
1042  }
1043 
1052  ++iterator_;
1053  return *this;
1054  }
1055 
1064  PredicatedTileIterator self(*this);
1065  operator++();
1066  return self;
1067  }
1068 
1071  void clear_mask() { iterator_.clear_mask(); }
1072 
1075  void enable_mask() { iterator_.enable_mask(); }
1076 
1079  void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
1080 
1083  void get_mask(Mask &mask) { iterator_.get_mask(mask); }
1084 
1086  CUTLASS_DEVICE
1087  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
1088  iterator_.load_with_pointer_offset(frag, pointer_offset);
1089  }
1090 
1092  CUTLASS_DEVICE
1093  void load(Fragment &frag) { load_with_pointer_offset(frag, 0); }
1094 
1096  CUTLASS_DEVICE
1097  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
1098  iterator_.store_with_pointer_offset(frag, pointer_offset);
1099  }
1100 
1102  CUTLASS_DEVICE
1103  void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }
1104 };
1105 
1107 
1108 } // namespace threadblock
1109 } // namespace transform
1110 } // namespace cutlass
1111 
typename UnderlyingIterator::AccessType AccessType
Definition: transform/threadblock/predicated_tile_iterator.h:405
CUTLASS_HOST_DEVICE void clear_mask()
Clears the predicate set efficiently.
Definition: transform/threadblock/predicated_tile_iterator.h:503
CUTLASS_DEVICE void load(Fragment &frag)
Loads a fragment from memory.
Definition: transform/threadblock/predicated_tile_iterator.h:533
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:62
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:355
CUTLASS_DEVICE void store(Fragment const &frag)
Store a fragment to memory.
Definition: transform/threadblock/predicated_tile_iterator.h:354
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: transform/threadblock/predicated_tile_iterator.h:1040
typename TileAccessIterator::Mask Mask
Predicate vector stores mask to guard accesses.
Definition: transform/threadblock/predicated_tile_iterator.h:185
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE PredicatedTileIterator & operator++()
Definition: transform/threadblock/predicated_tile_iterator.h:870
Coordinate in pitch-linear space.
Definition: pitch_linear.h:52
Architecture-specific operators on memory.
CUTLASS_HOST_DEVICE PredicatedTileIterator operator++(int)
Definition: transform/threadblock/predicated_tile_iterator.h:495
CUTLASS_HOST_DEVICE void set_mask(Mask const &mask)
Sets the predicate mask, overriding value stored in predicate iterator.
Definition: transform/threadblock/predicated_tile_iterator.h:515
T type
Definition: platform.h:351
CUTLASS_HOST_DEVICE PredicatedTileIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id)
Construct a PredicatedTileIterator with zero threadblock offset.
Definition: transform/threadblock/predicated_tile_iterator.h:465
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:249
Mapping function for pitch-linear memory.
Definition: pitch_linear.h:163
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:352
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:387
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:154
CUTLASS_HOST_DEVICE void clear_mask()
Clears the predicate set efficiently.
Definition: transform/threadblock/predicated_tile_iterator.h:696
typename Layout::Index Index
Definition: transform/threadblock/predicated_tile_iterator.h:579
typename Layout::TensorCoord TensorCoord
Definition: transform/threadblock/predicated_tile_iterator.h:584
typename Layout::LongIndex LongIndex
Definition: transform/threadblock/predicated_tile_iterator.h:580
Aligned array type.
Definition: array.h:511
CUTLASS_DEVICE void store_with_pointer_offset(Fragment const &frag, Index pointer_offset)
Store a fragment to memory.
Definition: transform/threadblock/predicated_tile_iterator.h:916
typename platform::remove_const< Element >::type * NonConstPointer
Definition: transform/threadblock/predicated_tile_iterator.h:394
CUTLASS_HOST_DEVICE PredicatedTileIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id)
Construct a PredicatedTileIterator with zero threadblock offset.
Definition: transform/threadblock/predicated_tile_iterator.h:1029
CUTLASS_DEVICE void store_with_pointer_offset(Fragment const &frag, Index pointer_offset)
Store a fragment to memory.
Definition: transform/threadblock/predicated_tile_iterator.h:732
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:246
CUTLASS_HOST_DEVICE PredicatedTileIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id, TensorCoord const &threadblock_offset)
Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID...
Definition: transform/threadblock/predicated_tile_iterator.h:641
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: transform/threadblock/predicated_tile_iterator.h:859
CUTLASS_DEVICE void load_with_pointer_offset(Fragment &frag, Index pointer_offset)
Definition: transform/threadblock/predicated_tile_iterator.h:298
CUTLASS_HOST_DEVICE PredicatedTileIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id, TensorCoord const &threadblock_offset)
Definition: transform/threadblock/predicated_tile_iterator.h:1008
CUTLASS_DEVICE void load(Fragment &frag)
Loads a fragment from memory.
Definition: transform/threadblock/predicated_tile_iterator.h:912
CUTLASS_HOST_DEVICE void clear_mask()
Clears the predicate set efficiently.
Definition: transform/threadblock/predicated_tile_iterator.h:1071
CUTLASS_HOST_DEVICE PredicatedTileIterator & operator++()
Definition: transform/threadblock/predicated_tile_iterator.h:484
CUTLASS_HOST_DEVICE PredicatedTileIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id)
Construct a PredicatedTileIterator with zero threadblock offset.
Definition: transform/threadblock/predicated_tile_iterator.h:848
CUTLASS_DEVICE void store_with_pointer_offset(Fragment const &frag, Index pointer_offset)
Store a fragment to memory.
Definition: transform/threadblock/predicated_tile_iterator.h:1097
CUTLASS_HOST_DEVICE PredicatedTileIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id, TensorCoord const &threadblock_offset)
Definition: transform/threadblock/predicated_tile_iterator.h:221
CUTLASS_HOST_DEVICE void set_mask(Mask const &mask)
Sets the predicate mask, overriding value stored in predicate iterator.
Definition: transform/threadblock/predicated_tile_iterator.h:1079
cutlass::Array< Element, ThreadMap::Iterations::kCount *ThreadMap::kElementsPerAccess > Fragment
Fragment object to be loaded or stored.
Definition: transform/threadblock/predicated_tile_iterator.h:408
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Construct the Params object given a pitch-linear tensor's layout.
Definition: transform/threadblock/predicated_tile_iterator.h:811
CUTLASS_HOST_DEVICE void get_mask(Mask &mask)
Gets the mask.
Definition: transform/threadblock/predicated_tile_iterator.h:714
cutlass::Array< Element, ThreadMap::Iterations::kCount *ThreadMap::kElementsPerAccess > Fragment
Fragment object to be loaded or stored.
Definition: transform/threadblock/predicated_tile_iterator.h:792
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: transform/threadblock/predicated_tile_iterator.h:248
typename UnderlyingIterator::Mask Mask
Predicate vector stores mask to guard accesses.
Definition: transform/threadblock/predicated_tile_iterator.h:411
CUTLASS_HOST_DEVICE PredicatedTileIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id, TensorCoord const &threadblock_offset)
Definition: transform/threadblock/predicated_tile_iterator.h:827
CUTLASS_HOST_DEVICE void get_mask(Mask &mask)
Gets the mask.
Definition: transform/threadblock/predicated_tile_iterator.h:521
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
CUTLASS_DEVICE void load(Fragment &frag)
Loads a fragment from memory.
Definition: transform/threadblock/predicated_tile_iterator.h:1093
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Construct the Params object given a column-major tensor's layout.
Definition: transform/threadblock/predicated_tile_iterator.h:429
CUTLASS_HOST_DEVICE void get_mask(Mask &mask)
Gets the mask.
Definition: transform/threadblock/predicated_tile_iterator.h:295
Template defining a shape used by pitch-linear operators.
Definition: pitch_linear.h:43
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
typename UnderlyingIterator::Mask Mask
Predicate vector stores mask to guard accesses.
Definition: transform/threadblock/predicated_tile_iterator.h:604
CUTLASS_HOST_DEVICE PredicatedTileIterator operator++(int)
Definition: transform/threadblock/predicated_tile_iterator.h:688
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:59
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Construct the Params object given a pitch-linear tensor's layout.
Definition: transform/threadblock/predicated_tile_iterator.h:199
CUTLASS_HOST_DEVICE half_t & operator++(half_t &lhs)
Definition: half.h:694
int64_t LongIndex
Long index type used for offsets.
Definition: pitch_linear.h:175
CUTLASS_HOST_DEVICE void set_mask(Mask const &mask)
Sets the predicate mask, overriding value stored in predicate iterator.
Definition: transform/threadblock/predicated_tile_iterator.h:708
CUTLASS_HOST_DEVICE void get_mask(Mask &mask)
Gets the mask.
Definition: transform/threadblock/predicated_tile_iterator.h:1083
CUTLASS_HOST_DEVICE PredicatedTileIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id, TensorCoord const &threadblock_offset)
Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID...
Definition: transform/threadblock/predicated_tile_iterator.h:448
cutlass::Array< Element, ThreadMap::Iterations::kCount *ThreadMap::kElementsPerAccess > Fragment
Fragment object to be loaded or stored.
Definition: transform/threadblock/predicated_tile_iterator.h:182
typename UnderlyingIterator::AccessType AccessType
Definition: transform/threadblock/predicated_tile_iterator.h:598
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: transform/threadblock/predicated_tile_iterator.h:667
CUTLASS_HOST_DEVICE PredicatedTileIterator operator++(int)
Definition: transform/threadblock/predicated_tile_iterator.h:882
CUTLASS_HOST_DEVICE PredicatedTileIterator & operator++()
Definition: transform/threadblock/predicated_tile_iterator.h:677
CUTLASS_HOST_DEVICE void enable_mask()
Enables the predicate set efficiently.
Definition: transform/threadblock/predicated_tile_iterator.h:509
CUTLASS_HOST_DEVICE PredicatedTileIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id)
Construct a PredicatedTileIterator with zero threadblock offset.
Definition: transform/threadblock/predicated_tile_iterator.h:658
CUTLASS_DEVICE void load(Fragment &frag)
Loads a fragment from memory.
Definition: transform/threadblock/predicated_tile_iterator.h:326
CUTLASS_DEVICE void store(Fragment const &frag)
Store a fragment to memory.
Definition: transform/threadblock/predicated_tile_iterator.h:738
typename Layout::LongIndex LongIndex
Definition: transform/threadblock/predicated_tile_iterator.h:161
typename Layout::TensorCoord TensorCoord
Definition: transform/threadblock/predicated_tile_iterator.h:391
cutlass::Array< Element, ThreadMap::Iterations::kCount *ThreadMap::kElementsPerAccess > Fragment
Fragment object to be loaded or stored.
Definition: transform/threadblock/predicated_tile_iterator.h:973
CUTLASS_HOST_DEVICE PredicatedTileIterator & operator++()
Definition: transform/threadblock/predicated_tile_iterator.h:1051
CUTLASS_HOST_DEVICE void enable_mask()
Enables the predicate set efficiently.
Definition: transform/threadblock/predicated_tile_iterator.h:1075
Templates calculating the address and predicates to the load of tiles from pitch-linear rank=2 tensor...
CUTLASS_HOST_DEVICE PredicatedTileIterator & operator++()
Definition: transform/threadblock/predicated_tile_iterator.h:259
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
CUTLASS_HOST_DEVICE void clear_mask()
Clears the predicate set efficiently.
Definition: transform/threadblock/predicated_tile_iterator.h:283
#define static_assert(__e, __m)
Definition: platform.h:153
CUTLASS_HOST_DEVICE void enable_mask()
Enables the predicate set efficiently.
Definition: transform/threadblock/predicated_tile_iterator.h:702
int32_t Index
Index type used for coordinates.
Definition: pitch_linear.h:172
CUTLASS_HOST_DEVICE PredicatedTileIterator operator++(int)
Definition: transform/threadblock/predicated_tile_iterator.h:1063
CUTLASS_HOST_DEVICE void get_mask(Mask &mask)
Gets the mask.
Definition: transform/threadblock/predicated_tile_iterator.h:902
cutlass::Array< Element, ThreadMap::Iterations::kCount *ThreadMap::kElementsPerAccess > Fragment
Fragment object to be loaded or stored.
Definition: transform/threadblock/predicated_tile_iterator.h:601
typename platform::remove_const< Element >::type * NonConstPointer
Definition: transform/threadblock/predicated_tile_iterator.h:780
CUTLASS_DEVICE void store(Fragment const &frag)
Store a fragment to memory.
Definition: transform/threadblock/predicated_tile_iterator.h:545
Definition: transform/threadblock/predicated_tile_iterator.h:133
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Construct the Params object given a pitch-linear tensor's layout.
Definition: transform/threadblock/predicated_tile_iterator.h:622
CUTLASS_HOST_DEVICE void clear_mask()
Clears the predicate set efficiently.
Definition: transform/threadblock/predicated_tile_iterator.h:890
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
CUTLASS_DEVICE void store(Fragment const &frag)
Store a fragment to memory.
Definition: transform/threadblock/predicated_tile_iterator.h:922
CUTLASS_DEVICE void load_with_pointer_offset(Fragment &frag, Index pointer_offset)
Loads a fragment from memory.
Definition: transform/threadblock/predicated_tile_iterator.h:720
CUTLASS_HOST_DEVICE void set_mask(Mask const &mask)
Sets the predicate mask, overriding value stored in predicate iterator.
Definition: transform/threadblock/predicated_tile_iterator.h:898
CUTLASS_DEVICE void load_with_pointer_offset(Fragment &frag, Index pointer_offset)
Loads a fragment from memory.
Definition: transform/threadblock/predicated_tile_iterator.h:1087
Definition: layout/matrix.h:343
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:151
typename platform::remove_const< Element >::type * NonConstPointer
Definition: transform/threadblock/predicated_tile_iterator.h:961
typename Layout::LongIndex LongIndex
Definition: transform/threadblock/predicated_tile_iterator.h:387
CUTLASS_HOST_DEVICE PredicatedTileIterator operator++(int)
Definition: transform/threadblock/predicated_tile_iterator.h:275
CUTLASS_DEVICE void load(Fragment &frag)
Loads a fragment from memory.
Definition: transform/threadblock/predicated_tile_iterator.h:726
typename UnderlyingIterator::Mask Mask
Predicate vector stores mask to guard accesses.
Definition: transform/threadblock/predicated_tile_iterator.h:795
CUTLASS_DEVICE void store_with_pointer_offset(Fragment const &frag, Index pointer_offset)
Store a fragment to memory.
Definition: transform/threadblock/predicated_tile_iterator.h:330
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Construct the Params object given a pitch-linear tensor's layout.
Definition: transform/threadblock/predicated_tile_iterator.h:992
CUTLASS_DEVICE void load_with_pointer_offset(Fragment &frag, Index pointer_offset)
Loads a fragment from memory.
Definition: transform/threadblock/predicated_tile_iterator.h:906
typename platform::remove_const< Element >::type * NonConstPointer
Definition: transform/threadblock/predicated_tile_iterator.h:168
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: transform/threadblock/predicated_tile_iterator.h:474
CUTLASS_HOST_DEVICE void enable_mask()
Enables the predicate set efficiently.
Definition: transform/threadblock/predicated_tile_iterator.h:287
typename UnderlyingIterator::AccessType AccessType
Definition: transform/threadblock/predicated_tile_iterator.h:969
CUTLASS_HOST_DEVICE PredicatedTileIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id)
Construct a PredicatedTileIterator with zero threadblock offset.
Definition: transform/threadblock/predicated_tile_iterator.h:237
CUTLASS_HOST_DEVICE void set_mask(Mask const &mask)
Sets the predicate mask, overriding value stored in predicate iterator.
Definition: transform/threadblock/predicated_tile_iterator.h:291
CUTLASS_DEVICE void store(Fragment const &frag)
Store a fragment to memory.
Definition: transform/threadblock/predicated_tile_iterator.h:1103
Definition: matrix_coord.h:39
typename Layout::TensorCoord TensorCoord
Definition: transform/threadblock/predicated_tile_iterator.h:165
typename UnderlyingIterator::Mask Mask
Predicate vector stores mask to guard accesses.
Definition: transform/threadblock/predicated_tile_iterator.h:976
CUTLASS_DEVICE void store_with_pointer_offset(Fragment const &frag, Index pointer_offset)
Store a fragment to memory.
Definition: transform/threadblock/predicated_tile_iterator.h:539
CUTLASS_HOST_DEVICE void enable_mask()
Enables the predicate set efficiently.
Definition: transform/threadblock/predicated_tile_iterator.h:894
CUTLASS_DEVICE void load_with_pointer_offset(Fragment &frag, Index pointer_offset)
Loads a fragment from memory.
Definition: transform/threadblock/predicated_tile_iterator.h:527
typename platform::remove_const< Element >::type * NonConstPointer
Definition: transform/threadblock/predicated_tile_iterator.h:587
Definition: layout/matrix.h:237