CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
predicated_tile_access_iterator.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  *modification, are permitted provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice,
7  *this list of conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright
9  *notice, this list of conditions and the following disclaimer in the
10  *documentation and/or other materials provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its
12  *contributors may be used to endorse or promote products derived from this
13  *software without specific prior written permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16  *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17  *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18  *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
19  *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
20  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
21  *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
22  *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING
23  *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
24  *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25  *
26  **************************************************************************************************/
40 #pragma once
41 
42 #include "cutlass/array.h"
43 #include "cutlass/coord.h"
44 #include "cutlass/cutlass.h"
45 #include "cutlass/layout/matrix.h"
47 #include "cutlass/matrix_shape.h"
49 #include "cutlass/tensor_ref.h"
50 #include "cutlass/tensor_view.h"
51 
53 
55 
56 namespace cutlass {
57 namespace transform {
58 namespace threadblock {
59 
61 
// Primary template declaration; the layout-specific partial specializations
// follow below. NOTE(review): the declaration line itself (source line 66)
// is not present in this Doxygen rendering.
64 template <typename Shape, typename Element, typename Layout, int AdvanceRank,
65  typename ThreadMap, typename AccessType>
67 
69 
// Predicated tile access iterator specialized for layout::PitchLinear.
//
// Each thread visits the accesses of one Shape-sized tile in the order
// vector -> contiguous -> strided (see operator++ and set_iteration_index);
// get() returns the current AccessType pointer and valid() its guard bit.
// The first tile is a "residue" tile: its extent along the advance rank is
// extent % tile size (a full tile when that remainder is zero). After the
// first add_tile_offset() call, predicates only check the dimension
// orthogonal to the advance rank.
//
// NOTE(review): this listing is a Doxygen rendering; gaps in the embedded
// line numbers (e.g. 77, 97, 149, 162, 260-264) are source lines that are
// not shown here.
72 template <typename Shape_, typename Element_, int AdvanceRank,
73  typename ThreadMap_, typename AccessType_>
74 class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
75  AdvanceRank, ThreadMap_, AccessType_> {
76  public:
// NOTE(review): the static_assert message below reads "may along advance
// along"; "may advance along" was presumably intended.
78  AdvanceRank == 0 || AdvanceRank == 1,
79  "Specialization for pitch-linear iterator may along advance along the "
80  "contiguous(rank=0) or strided(rank=1) dimension.");
81 
82  using Shape = Shape_;
83  using Element = Element_;
84  using Layout = layout::PitchLinear;
85  static int const kAdvanceRank = AdvanceRank;
86  using ThreadMap = ThreadMap_;
87  using AccessType = AccessType_;
88 
89  using Index = typename Layout::Index;
90  using LongIndex = typename Layout::LongIndex;
91 
94  using TensorCoord = typename Layout::TensorCoord;
95 
96  using Pointer = Element *;
98 
// Number of AccessType-sized accesses that make up one ThreadMap vector.
99  static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements;
100 
101  static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
102  "Vectors implied by the thread map must be divisible by the access type.");
103 
// Guard predicates are packed kPredicatesPerByte to a byte, using the low
// bits of each byte (bit byte_idx * 8 + bit_idx), and four bytes per 32-bit
// word.
104  static int const kPredicatesPerByte = 4;
105  static int const kPredicatesPerWord = 4 * kPredicatesPerByte;
106 
107  static int const kPredicateCount = ThreadMap::Iterations::kCount * kAccessesPerVector;
108 
110  static int const kPredicateByteCount =
111  (kPredicateCount + kPredicatesPerByte - 1) / kPredicatesPerByte;
112  static int const kPredicateWordCount = (kPredicateByteCount + 3) / 4;
113 
// NOTE(review): kPredicateMask is not referenced anywhere in this listing.
114  static unsigned const kPredicateMask = (1u << kPredicatesPerByte) - 1u;
115 
116  static_assert(kPredicateWordCount <= 4, "Too many predicates.");
117 
// Packed predicate words, as exchanged through get_mask()/set_mask().
119  using Mask = Array<uint32_t, kPredicateWordCount>;
120 
// Precomputed pointer increments derived from the layout stride. The
// iterator keeps a reference to its Params (params_), so the Params object
// must outlive the iterator.
122  class Params {
123  public:
125 
126  private:
128  int stride_;
131  int inc_strided_;
134  int inc_next_;
137  int inc_advance_;
138 
139  public:
140 
141  // Default ctor
143  Params(): stride_(0), inc_strided_(0), inc_next_(0), inc_advance_(0) { }
144 
// Builds the increments from the pitch-linear stride. inc_advance_ moves one
// whole tile along the advance rank, in bytes; inc_strided_ moves between
// strided iterations; inc_next_ moves from the last strided iteration of a
// tile to the start of the next one.
// NOTE(review): the continuation lines of the inc_strided_ and inc_next_
// expressions (source lines 149 and 162) are not shown here. The increments
// are stored as int, so very large strides or tiles could overflow --
// consider LongIndex.
147  Params(Layout const &layout) : stride_(layout.stride(0)) {
148  inc_strided_ = (stride_ * ThreadMap::Delta::kStrided) *
150 
151  if (kAdvanceRank) {
152  // advance along strided dimension
153  inc_advance_ =
154  Shape::kStrided * stride_ * sizeof_bits<Element>::value / 8;
155  } else {
156  // advance along contiguous dimension
157  inc_advance_ = Shape::kContiguous * sizeof_bits<Element>::value / 8;
158  }
159 
160  inc_next_ = inc_advance_ - (ThreadMap::Iterations::kStrided - 1) *
161  ThreadMap::Delta::kStrided * stride_ *
163  };
164  };
165 
166  private:
// pointer_ is a byte pointer; element offsets must be scaled by
// sizeof_bits<Element>::value / 8 before being added (see add_pointer_offset).
168  using BytePointer = char *;
169 
170  private:
171  //
172  // Data members
173  //
174 
// Reference to externally owned Params; it must outlive this iterator.
176  Params const &params_;
177 
179  BytePointer pointer_;
180 
// Packed guard predicates for every access of the current tile.
182  uint32_t predicates_[kPredicateWordCount];
183 
185  TensorCoord extent_;
186 
188  TensorCoord thread_offset_;
189 
// Size of the residue (first) tile along the advance rank, as a coordinate.
191  TensorCoord residue_offset_;
192 
// True until the first add_tile_offset() call.
194  bool is_residue_tile_;
195 
197  int iteration_vector_;
198 
200  int iteration_contiguous_;
201 
203  int iteration_strided_;
204 
205  private:
// Recomputes the packed guard predicate for every access of the tile,
// starting from thread_offset_, against `extent`. When is_steady_state is
// false, both dimensions are checked. When it is true, only the dimension
// orthogonal to the advance rank is checked.
207  CUTLASS_DEVICE
208  void compute_predicates_(
210  TensorCoord extent,
212  bool is_steady_state = false) {
213 
215  for (int i = 0; i < kPredicateWordCount; ++i) {
216  predicates_[i] = 0u;
217  }
218 
219  for (int access_idx = 0; access_idx < ThreadMap::Iterations::kCount * kAccessesPerVector; ++access_idx) {
220 
221  int s = access_idx / (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
222 
223  int access_residual = access_idx % (ThreadMap::Iterations::kContiguous * kAccessesPerVector);
224 
225  int c = access_residual / kAccessesPerVector;
226  int v = access_residual % kAccessesPerVector;
227 
228  TensorCoord iteration_coord(c * ThreadMap::Delta::kContiguous + v * AccessType::kElements,
229  s * ThreadMap::Delta::kStrided);
230 
231  TensorCoord coord = thread_offset_ + iteration_coord;
232 
233  bool guard;
234 
235  if (is_steady_state) {
236  if (kAdvanceRank == 0) {
237  guard = (coord.strided() < extent.strided());
238  } else {
239  guard = (coord.contiguous() < extent.contiguous());
240  }
241  } else {
242  guard = (coord.strided() < extent.strided() &&
243  coord.contiguous() < extent.contiguous());
244  }
245 
// Same flat ordering as set_iteration_index() and valid().
246  int pred_idx = v + kAccessesPerVector * (c + ThreadMap::Iterations::kContiguous * s);
247 
248  int word_idx = pred_idx / kPredicatesPerWord;
249  int residual = pred_idx % kPredicatesPerWord;
250  int byte_idx = residual / kPredicatesPerByte;
251  int bit_idx = residual % kPredicatesPerByte;
252 
253  predicates_[word_idx] |= (unsigned(guard) << (byte_idx * 8 + bit_idx));
254 
255  }
256 
257  }
258 
259  public:
// Constructor (its opening lines are not shown in this rendering).
//   params             - precomputed increments; held by reference
//   pointer            - pointer to element (0, 0) of the tensor
//   extent             - tensor extent in pitch-linear coordinates
//   thread_id          - thread index, mapped through ThreadMap::initial_offset
//   threadblock_offset - starting coordinate of the threadblock's tile
// Computes the residue-tile extent, moves pointer_ to this thread's first
// element, and sets the predicates for the residue tile.
265  Params const &params,
267  Pointer pointer,
269  TensorCoord extent,
271  int thread_id,
273  TensorCoord const &threadblock_offset)
274  : params_(params),
275  pointer_(reinterpret_cast<BytePointer>(
276  const_cast<NonConstPointer>(pointer))),
277  extent_(extent),
278  is_residue_tile_(true) {
279 
280  TensorCoord residue_extent;
281  if (kAdvanceRank) {
282 
283  Index residue_size = (extent_[kAdvanceRank] % Shape::kStrided);
284  if (!residue_size) {
285  residue_size = Shape::kStrided;
286  }
287 
288  residue_offset_ = make_Coord(0, residue_size);
289  residue_extent = make_Coord(
290  extent_.contiguous(),
291  min(threadblock_offset.strided() + residue_offset_.strided(), extent_.strided())
292  );
293 
294  } else {
295 
296  Index residue_size = (extent_[kAdvanceRank] % Shape::kContiguous);
297  if (!residue_size) {
298  residue_size = Shape::kContiguous;
299  }
300  residue_offset_ = make_Coord(residue_size, 0);
301  residue_extent = make_Coord(
302  min(extent_.contiguous(), threadblock_offset.contiguous() + residue_offset_.contiguous()),
303  extent_.strided()
304  );
305  }
306 
307  // Per-thread offset in logical coordinates of tensor
308  thread_offset_ = threadblock_offset + ThreadMap::initial_offset(thread_id);
309 
310  // update internal pointers
311  Layout layout(params_.stride_);
312  add_pointer_offset(layout(thread_offset_));
313 
314  compute_predicates_(residue_extent, false);
315 
316  set_iteration_index(0);
317  }
318 
// Same as above with a zero threadblock offset.
323  Params const &params,
325  Pointer pointer,
327  TensorCoord extent,
329  int thread_id)
330  : PredicatedTileAccessIterator(params, pointer, extent, thread_id,
331  make_Coord(0, 0)) {}
332 
// Decomposes a flat access index into (vector, contiguous, strided)
// iteration indices; this is the same ordering used for predicate bits.
335  void set_iteration_index(int index) {
336 
337  iteration_vector_ = index % kAccessesPerVector;
338  int residual_access = index / kAccessesPerVector;
339 
340  iteration_contiguous_ = residual_access % ThreadMap::Iterations::kContiguous;
341  iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous;
342 
343  }
344 
// pointer_offset is in units of Element; it is converted to bytes here.
347  void add_pointer_offset(LongIndex pointer_offset) {
348  pointer_ += sizeof_bits<Element>::value * pointer_offset / 8;
349  }
350 
// Advances the iterator by whole tiles (tile_offset is counted in tiles).
// The first call also leaves the residue tile. It moves past
// residue_offset_, switches the predicates to the steady-state check, and
// then advances one tile fewer along the advance rank, because stepping
// over the residue already accounts for one tile.
// NOTE(review): pointer_ is a byte pointer, but the
// Shape::kContiguous * tile_offset.contiguous() and
// Shape::kStrided * tile_offset.strided() terms are added without the
// sizeof_bits<Element>::value / 8 scaling used in add_pointer_offset, and
// the strided term does not use params_.stride_. This is only harmless when
// the tile offset along the non-advance dimension is zero -- confirm or fix.
// The inc_advance_ products are also computed in int.
352  CUTLASS_DEVICE
354  TensorCoord const &tile_offset) {
355  if (is_residue_tile_) {
356 
357  thread_offset_ += residue_offset_;
358 
359  Layout layout(params_.stride_);
360  add_pointer_offset(layout(residue_offset_));
361 
362  compute_predicates_(extent_, true);
363 
364  if (kAdvanceRank) {
365  pointer_ += params_.inc_advance_ * (tile_offset.strided() - 1);
366  pointer_ += Shape::kContiguous * tile_offset.contiguous();
367  } else {
368  pointer_ += params_.inc_advance_ * (tile_offset.contiguous() - 1);
369  pointer_ += Shape::kStrided * tile_offset.strided();
370  }
371  } else {
372  if (kAdvanceRank) {
373  pointer_ += params_.inc_advance_ * tile_offset.strided();
374  pointer_ += Shape::kContiguous * tile_offset.contiguous();
375  } else {
376  pointer_ += params_.inc_advance_ * tile_offset.contiguous();
377  pointer_ += Shape::kStrided * tile_offset.strided();
378  }
379  }
380  is_residue_tile_ = false;
381  }
382 
// Returns the address of the current access. This is pointer_ plus the
// contiguous iteration offset (converted to bytes), followed by
// iteration_vector_ AccessType-sized steps.
385  AccessType *get() const {
386  return reinterpret_cast<AccessType *>(
387  pointer_ +
388  iteration_contiguous_ * (ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value) / 8) + iteration_vector_;
389  }
390 
// Prefix increment (its signature lines are not shown in this rendering).
// Steps the vector index, then the contiguous index, then the strided
// index. After the last strided iteration, the pointer moves by
// inc_next_ - inc_advance_, i.e. back to the start of the current tile (see
// the comments below).
394 
395  ++iteration_vector_;
396  if (iteration_vector_ < kAccessesPerVector) {
397  return *this;
398  }
399 
400  iteration_vector_ = 0;
401  ++iteration_contiguous_;
402 
403  if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) {
404  return *this;
405  }
406 
407  // Enter here only if (iteration_contiguous_ ==
408  // ThreadMap::Iterations::kContiguous)
409  iteration_contiguous_ = 0;
410  ++iteration_strided_;
411 
412  if (iteration_strided_ < ThreadMap::Iterations::kStrided) {
413  pointer_ += params_.inc_strided_;
414  return *this;
415  }
416 
417  // Enter here only if (iteration_strided_ == ThreadMap::Iterations::kStrided)
418  // which means we enter the next tile.
419  iteration_strided_ = 0;
420 
421  // advance to next tile
422  pointer_ += params_.inc_next_;
423 
424  // now return to start tile - if the iterator is subsequently advanced, this
425  // subtraction as well as the subsequent integer addition are both elided by
426  // the compiler.
427  pointer_ -= params_.inc_advance_;
428 
429  return *this;
430  }
431 
// Postfix increment: returns a copy of the iterator taken before advancing.
435  PredicatedTileAccessIterator self(*this);
436  operator++();
437  return self;
438  }
439 
// Sets every predicate to false (all accesses disabled).
442  void clear_mask() {
444  for (int i = 0; i < kPredicateWordCount; ++i) {
445  predicates_[i] = 0u;
446  }
447 
448  }
449 
// Sets every predicate word to all-ones (all accesses enabled).
452  void enable_mask() {
454  for (int i = 0; i < kPredicateWordCount; ++i) {
455  predicates_[i] = 0xffffffff;
456  }
457  }
458 
// Overwrites the predicate words with `mask`.
461  void set_mask(Mask const &mask) {
463  for (int i = 0; i < kPredicateWordCount; ++i) {
464  predicates_[i] = mask[i];
465  }
466 
467  }
468 
// Copies the predicate words into `mask`.
471  void get_mask(Mask &mask) {
473  for (int i = 0; i < kPredicateWordCount; ++i) {
474  mask[i] = predicates_[i];
475  }
476  }
477 
// Returns the guard predicate of the current (vector, contiguous, strided)
// access, using the same bit layout as compute_predicates_.
480  bool valid() {
481 
482 
483  int pred_idx =
484  iteration_vector_ + kAccessesPerVector * (iteration_contiguous_ + iteration_strided_ * ThreadMap::Iterations::kContiguous);
485 
486  int word_idx = pred_idx / kPredicatesPerWord;
487  int residual = pred_idx % kPredicatesPerWord;
488  int byte_idx = residual / kPredicatesPerByte;
489  int bit_idx = residual % kPredicatesPerByte;
490 
491  bool pred = (predicates_[word_idx] & (1u << (byte_idx * 8 + bit_idx))) != 0;
492  return pred;
493 
494 
// NOTE(review): the commented-out line below is leftover debugging code.
495  //return true;
496  }
497 };
498 
500 
// Predicated tile access iterator for layout::ColumnMajor.
//
// Adapter over the pitch-linear specialization. (row, column) coordinates
// map to pitch-linear (contiguous, strided) = (row, column), and the advance
// rank is passed through unchanged. Every operation forwards to iterator_.
//
// NOTE(review): this listing is a Doxygen rendering. Gaps in the embedded
// line numbers are source lines that are not shown here, e.g. the
// static_assert( head, the head of the UnderlyingIterator declaration, and
// the constructor and operator signatures.
508 template <typename Shape_, typename Element_, int AdvanceRank,
509  typename ThreadMap_, typename AccessType_>
510 class PredicatedTileAccessIterator<Shape_, Element_, layout::ColumnMajor,
511  AdvanceRank, ThreadMap_, AccessType_> {
512  public:
// NOTE(review): the static_assert message below is copied from the
// pitch-linear specialization ("may along advance along" presumably means
// "may advance along").
514  AdvanceRank == 0 || AdvanceRank == 1,
515  "Specialization for pitch-linear iterator may along advance along the "
516  "contiguous(rank=0) or strided(rank=1) dimension.");
517 
518  using Shape = Shape_;
519  using Element = Element_;
520  using Layout = layout::ColumnMajor;
521  static int const kAdvanceRank = AdvanceRank;
522  using ThreadMap = ThreadMap_;
523  using AccessType = AccessType_;
524 
525  using Index = typename Layout::Index;
526  using LongIndex = typename Layout::LongIndex;
527 
531 
532  using Pointer = Element *;
534 
// Underlying pitch-linear iterator. Advance rank 0 (row) maps to contiguous
// and advance rank 1 (column) maps to strided.
537  layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap, AccessType>;
538 
540  using Mask = typename UnderlyingIterator::Mask;
541 
542  static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
543 
// Wraps the underlying iterator's Params, built from the column-major stride.
545  class Params {
546  private:
548 
550  typename UnderlyingIterator::Params params_;
551 
552  public:
553 
556  Params() { }
557 
560  Params(Layout const &layout)
561  : params_(layout::PitchLinear(layout.stride(0))){};
562  };
563 
564  private:
565  //
566  // Data members
567  //
568 
570  UnderlyingIterator iterator_;
571 
572  public:
// Constructor (its opening lines are not shown in this rendering). Converts
// the (row, column) extent and threadblock offset to pitch-linear
// (contiguous, strided) coordinates and constructs the underlying iterator.
578  Params const &params,
580  Pointer pointer,
582  TensorCoord extent,
584  int thread_id,
586  TensorCoord const &threadblock_offset)
587  : iterator_(params.params_, pointer,
588  layout::PitchLinearCoord(extent.row(), extent.column()),
589  thread_id,
590  layout::PitchLinearCoord(threadblock_offset.row(),
591  threadblock_offset.column())) {}
592 
// Same as above with a zero threadblock offset.
596  Params const &params,
597  Pointer pointer,
598  TensorCoord extent,
599  int thread_id
600  )
601  : PredicatedTileAccessIterator(params, pointer, extent, thread_id,
602  make_Coord(0, 0)) {}
603 
// The remaining members forward to the underlying pitch-linear iterator.
606  void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
607 
610  void add_pointer_offset(LongIndex pointer_offset) {
611  iterator_.add_pointer_offset(pointer_offset);
612  }
613 
// Tile offsets are converted from (row, column) to (contiguous, strided).
617  void add_tile_offset(TensorCoord const &tile_offset) {
618  iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
619  }
620 
623  AccessType *get() const {
624  return reinterpret_cast<AccessType *>(iterator_.get());
625  }
626 
// Prefix increment (its signature lines are not shown in this rendering).
635  ++iterator_;
636  return *this;
637  }
638 
// Postfix increment: returns a copy taken before advancing.
647  PredicatedTileAccessIterator self(*this);
648  operator++();
649  return self;
650  }
651 
654  void clear_mask() { iterator_.clear_mask(); }
655 
658  void enable_mask() { iterator_.enable_mask(); }
659 
662  void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
663 
666  void get_mask(Mask &mask) { iterator_.get_mask(mask); }
667 
670  bool valid() {
671  return iterator_.valid();
672  }
673 };
674 
676 
// Predicated tile access iterator for layout::RowMajor.
//
// Adapter over the pitch-linear specialization. (row, column) coordinates
// map to pitch-linear (contiguous, strided) = (column, row), so the advance
// rank is flipped: 0 (row) becomes 1 (strided), and 1 (column) becomes 0
// (contiguous). Every operation forwards to iterator_.
//
// NOTE(review): this listing is a Doxygen rendering. Gaps in the embedded
// line numbers are source lines that are not shown here, e.g. the
// static_assert( head, the head of the UnderlyingIterator declaration, and
// the constructor and operator signatures.
684 template <typename Shape_, typename Element_, int AdvanceRank,
685  typename ThreadMap_, typename AccessType_>
686 class PredicatedTileAccessIterator<Shape_, Element_, layout::RowMajor,
687  AdvanceRank, ThreadMap_, AccessType_> {
688  public:
// NOTE(review): the static_assert message below is copied from the
// pitch-linear specialization ("may along advance along" presumably means
// "may advance along").
690  AdvanceRank == 0 || AdvanceRank == 1,
691  "Specialization for pitch-linear iterator may along advance along the "
692  "contiguous(rank=0) or strided(rank=1) dimension.");
693 
694  using Shape = Shape_;
695  using Element = Element_;
696  using Layout = layout::RowMajor;
697  static int const kAdvanceRank = AdvanceRank;
698  using ThreadMap = ThreadMap_;
699  using AccessType = AccessType_;
700 
701  using Index = typename Layout::Index;
702  using LongIndex = typename Layout::LongIndex;
703 
707 
708  using Pointer = Element *;
710 
// Underlying pitch-linear iterator with the advance rank flipped (rows are
// the strided dimension, columns the contiguous one).
713  layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap, AccessType>;
714 
715  static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
716 
718  using Mask = typename UnderlyingIterator::Mask;
719 
// Wraps the underlying iterator's Params, built from the row-major stride.
721  class Params {
722  private:
724 
726  typename UnderlyingIterator::Params params_;
727 
728  public:
729 
732  Params() { }
733 
736  Params(Layout const &layout)
737  : params_(layout::PitchLinear(layout.stride(0))){};
738  };
739 
740  private:
741  //
742  // Data members
743  //
744 
746  UnderlyingIterator iterator_;
747 
748  public:
// Constructor (its opening lines are not shown in this rendering). Converts
// the (row, column) extent and threadblock offset to pitch-linear
// (contiguous, strided) = (column, row) and constructs the underlying
// iterator.
754  Params const &params,
756  Pointer pointer,
758  TensorCoord extent,
760  int thread_id,
762  TensorCoord const &threadblock_offset)
763  : iterator_(params.params_, pointer,
764  layout::PitchLinearCoord(extent.column(), extent.row()),
765  thread_id,
766  layout::PitchLinearCoord(threadblock_offset.column(),
767  threadblock_offset.row())) {}
768 
// Same as above with a zero threadblock offset.
772  Params const &params,
773  Pointer pointer,
774  TensorCoord extent,
775  int thread_id
776  )
777  : PredicatedTileAccessIterator(params, pointer, extent, thread_id,
778  make_Coord(0, 0)) {}
779 
// The remaining members forward to the underlying pitch-linear iterator.
782  void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
783 
786  void add_pointer_offset(LongIndex pointer_offset) {
787  iterator_.add_pointer_offset(pointer_offset);
788  }
789 
// Tile offsets are converted from (row, column) to (contiguous, strided) =
// (column, row).
793  void add_tile_offset(TensorCoord const &tile_offset) {
794  iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
795  }
796 
799  AccessType *get() const {
800  return reinterpret_cast<AccessType *>(iterator_.get());
801  }
802 
// Prefix increment (its signature lines are not shown in this rendering).
811  ++iterator_;
812  return *this;
813  }
814 
// Postfix increment: returns a copy taken before advancing.
823  PredicatedTileAccessIterator self(*this);
824  operator++();
825  return self;
826  }
827 
830  void clear_mask() { iterator_.clear_mask(); }
831 
834  void enable_mask() { iterator_.enable_mask(); }
835 
838  void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
839 
842  void get_mask(Mask &mask) { iterator_.get_mask(mask); }
843 
846  bool valid() {
847  return iterator_.valid();
848  }
849 };
850 
852 
861 
// Predicated tile access iterator for layout::ColumnMajorInterleaved<K>.
//
// Adapter over the pitch-linear specialization. The tile shape becomes
// (kRow * K, kColumn / K), and (row, column) coordinates map to pitch-linear
// (contiguous, strided) = (row * K, column / K). The advance rank is passed
// through unchanged, and every operation forwards to iterator_.
//
// NOTE(review): column / kInterleavedK truncates, so extents or offsets that
// are not multiples of kInterleavedK lose their remainder. Presumably callers
// guarantee divisibility -- confirm.
// NOTE(review): this listing is a Doxygen rendering; gaps in the embedded
// line numbers (e.g. the static_assert( head and the Layout alias) are
// source lines that are not shown here.
862 template <typename Shape_, typename Element_, int AdvanceRank,
863  typename ThreadMap_, typename AccessType_, int InterleavedK>
864 class PredicatedTileAccessIterator<Shape_, Element_,
865  layout::ColumnMajorInterleaved<InterleavedK>,
866  AdvanceRank, ThreadMap_, AccessType_> {
867  public:
// NOTE(review): the static_assert message below is copied from the
// pitch-linear specialization ("may along advance along" presumably means
// "may advance along").
869  AdvanceRank == 0 || AdvanceRank == 1,
870  "Specialization for pitch-linear iterator may along advance along the "
871  "contiguous(rank=0) or strided(rank=1) dimension.");
872 
873  using Shape = Shape_;
874  using Element = Element_;
875  static int const kInterleavedK = InterleavedK;
877  static int const kAdvanceRank = AdvanceRank;
878  using ThreadMap = ThreadMap_;
879  using AccessType = AccessType_;
880 
881  using Index = typename Layout::Index;
882  using LongIndex = typename Layout::LongIndex;
883 
887 
888  using Pointer = Element *;
890 
// Underlying pitch-linear iterator over the interleaved tile shape
// (kRow * K, kColumn / K); the advance rank is unchanged.
892  layout::PitchLinearShape<Shape::kRow * kInterleavedK,
893  Shape::kColumn / kInterleavedK>,
894  Element, layout::PitchLinear, (kAdvanceRank == 0 ? 0 : 1), ThreadMap,
896 
897  static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
898 
900  using Mask = typename UnderlyingIterator::Mask;
901 
// Wraps the underlying iterator's Params, built from the interleaved stride.
903  class Params {
904  private:
906 
908  typename UnderlyingIterator::Params params_;
909 
910  public:
912  Params() {}
913 
916  Params(Layout const &layout)
917  : params_(layout::PitchLinear(layout.stride(0))) {}
918  };
919 
920  private:
921  //
922  // Data members
923  //
924 
926  UnderlyingIterator iterator_;
927 
928  public:
// Constructor (its opening lines are not shown in this rendering). Maps the
// extent and threadblock offset to (row * K, column / K) and constructs the
// underlying iterator.
934  Params const &params,
936  Pointer pointer,
938  TensorCoord extent,
940  int thread_id,
942  TensorCoord const &threadblock_offset)
943  : iterator_(params.params_, pointer,
944  layout::PitchLinearCoord(extent.row() * kInterleavedK,
945  extent.column() / kInterleavedK),
946  thread_id,
947  layout::PitchLinearCoord(
948  threadblock_offset.row() * kInterleavedK,
949  threadblock_offset.column() / kInterleavedK)) {}
950 
// Same as above with a zero threadblock offset.
954  Params const &params,
955  Pointer pointer,
956  TensorCoord extent,
957  int thread_id
958  )
959  : PredicatedTileAccessIterator(params, pointer, extent, thread_id,
960  make_Coord(0, 0)) {}
961 
// The remaining members forward to the underlying pitch-linear iterator.
964  void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
965 
968  void add_pointer_offset(LongIndex pointer_offset) {
969  iterator_.add_pointer_offset(pointer_offset);
970  }
971 
// Tile offsets are passed through as (row, column) tile counts. The
// underlying tile is this tile reshaped to (kRow * K, kColumn / K), so tile
// counts are not rescaled.
975  void add_tile_offset(TensorCoord const &tile_offset) {
976  iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
977  }
978 
981  AccessType *get() const {
982  return reinterpret_cast<AccessType *>(iterator_.get());
983  }
984 
// Prefix increment (its signature lines are not shown in this rendering).
993  ++iterator_;
994  return *this;
995  }
996 
// Postfix increment: returns a copy taken before advancing.
1005  PredicatedTileAccessIterator self(*this);
1006  operator++();
1007  return self;
1008  }
1009 
1012  void clear_mask() { iterator_.clear_mask(); }
1013 
1016  void enable_mask() { iterator_.enable_mask(); }
1017 
1020  void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
1021 
1024  void get_mask(Mask &mask) { iterator_.get_mask(mask); }
1025 
1028  bool valid() { return iterator_.valid(); }
1029 };
1030 
1032 
// Predicated tile access iterator for layout::RowMajorInterleaved<K>.
//
// Adapter over the pitch-linear specialization. The tile shape becomes
// (kColumn * K, kRow / K), and (row, column) coordinates map to pitch-linear
// (contiguous, strided) = (column * K, row / K). The advance rank is flipped
// (0 -> 1, 1 -> 0), and every operation forwards to iterator_.
//
// NOTE(review): row / kInterleavedK truncates, so extents or offsets that are
// not multiples of kInterleavedK lose their remainder. Presumably callers
// guarantee divisibility -- confirm.
// NOTE(review): this listing is a Doxygen rendering; gaps in the embedded
// line numbers (e.g. the Layout alias) are source lines that are not shown
// here.
1041 template <typename Shape_, typename Element_, int AdvanceRank,
1042  typename ThreadMap_, typename AccessType_, int InterleavedK>
1043 class PredicatedTileAccessIterator<Shape_, Element_,
1044  layout::RowMajorInterleaved<InterleavedK>,
1045  AdvanceRank, ThreadMap_, AccessType_> {
1046  public:
// NOTE(review): the static_assert message below is copied from the
// pitch-linear specialization ("may along advance along" presumably means
// "may advance along").
1047  static_assert(
1048  AdvanceRank == 0 || AdvanceRank == 1,
1049  "Specialization for pitch-linear iterator may along advance along the "
1050  "contiguous(rank=0) or strided(rank=1) dimension.");
1051 
1052  using Shape = Shape_;
1053  using Element = Element_;
1054  static int const kInterleavedK = InterleavedK;
1056  static int const kAdvanceRank = AdvanceRank;
1057  using ThreadMap = ThreadMap_;
1058  using AccessType = AccessType_;
1059 
1060  using Index = typename Layout::Index;
1061  using LongIndex = typename Layout::LongIndex;
1062 
1066 
1067  using Pointer = Element *;
1069 
// Underlying pitch-linear iterator over the interleaved tile shape
// (kColumn * K, kRow / K), with the advance rank flipped.
1071  layout::PitchLinearShape<Shape::kColumn * kInterleavedK,
1072  Shape::kRow / kInterleavedK>,
1073  Element, layout::PitchLinear, (kAdvanceRank == 0 ? 1 : 0), ThreadMap,
1075 
1076 
1077  static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
1078 
1080  using Mask = typename UnderlyingIterator::Mask;
1081 
// Wraps the underlying iterator's Params, built from the interleaved stride.
1083  class Params {
1084  private:
1086 
1088  typename UnderlyingIterator::Params params_;
1089 
1090  public:
1092  Params() {}
1093 
1096  Params(Layout const &layout)
1097  : params_(layout::PitchLinear(layout.stride(0))) {}
1098  };
1099 
1100  private:
1101  //
1102  // Data members
1103  //
1104 
1106  UnderlyingIterator iterator_;
1107 
1108  public:
// Constructor (its opening lines are not shown in this rendering). Maps the
// extent and threadblock offset to (column * K, row / K) and constructs the
// underlying iterator.
1114  Params const &params,
1116  Pointer pointer,
1118  TensorCoord extent,
1120  int thread_id,
1122  TensorCoord const &threadblock_offset)
1123  : iterator_(params.params_, pointer,
1124  layout::PitchLinearCoord(extent.column() * kInterleavedK,
1125  extent.row() / kInterleavedK),
1126  thread_id,
1127  layout::PitchLinearCoord(
1128  threadblock_offset.column() * kInterleavedK,
1129  threadblock_offset.row() / kInterleavedK)) {}
1130 
// Same as above with a zero threadblock offset.
1134  Params const &params,
1135  Pointer pointer,
1136  TensorCoord extent,
1137  int thread_id
1138  )
1139  : PredicatedTileAccessIterator(params, pointer, extent, thread_id,
1140  make_Coord(0, 0)) {}
1141 
// The remaining members forward to the underlying pitch-linear iterator.
1144  void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
1145 
1148  void add_pointer_offset(LongIndex pointer_offset) {
1149  iterator_.add_pointer_offset(pointer_offset);
1150  }
1151 
// Tile offsets are passed through as (column, row) tile counts. The
// underlying tile is this tile reshaped to (kColumn * K, kRow / K), so tile
// counts are not rescaled.
1155  void add_tile_offset(TensorCoord const &tile_offset) {
1156  iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
1157  }
1158 
1161  AccessType *get() const {
1162  return reinterpret_cast<AccessType *>(iterator_.get());
1163  }
1164 
// Prefix increment (its signature lines are not shown in this rendering).
1173  ++iterator_;
1174  return *this;
1175  }
1176 
// Postfix increment: returns a copy taken before advancing.
1185  PredicatedTileAccessIterator self(*this);
1186  operator++();
1187  return self;
1188  }
1189 
1192  void clear_mask() { iterator_.clear_mask(); }
1193 
1196  void enable_mask() { iterator_.enable_mask(); }
1197 
1200  void set_mask(Mask const &mask) { iterator_.set_mask(mask); }
1201 
1204  void get_mask(Mask &mask) { iterator_.get_mask(mask); }
1205 
1208  bool valid() { return iterator_.valid(); }
1209 };
1210 
1212 
1213 } // namespace threadblock
1214 } // namespace transform
1215 } // namespace cutlass
1216 
CUTLASS_HOST_DEVICE void set_iteration_index(int index)
Overrides the internal iteration index.
Definition: predicated_tile_access_iterator.h:606
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:62
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:355
CUTLASS_HOST_DEVICE void get_mask(Mask &mask)
Gets the mask.
Definition: predicated_tile_access_iterator.h:842
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE void set_mask(Mask const &mask)
Sets the predicate mask, overriding value stored in predicate iterator.
Definition: predicated_tile_access_iterator.h:1200
Coordinate in pitch-linear space.
Definition: pitch_linear.h:52
Defines a structure containing strides, bounds, and a pointer to tensor data.
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id, TensorCoord const &threadblock_offset)
Definition: predicated_tile_access_iterator.h:932
T type
Definition: platform.h:351
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:249
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id)
Construct a PredicatedTileAccessIterator with zero threadblock offset.
Definition: predicated_tile_access_iterator.h:1133
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: predicated_tile_access_iterator.h:610
Mapping function for pitch-linear memory.
Definition: pitch_linear.h:163
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:352
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:387
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:154
CUTLASS_HOST_DEVICE void enable_mask()
Enables all predicates (sets every predicate bit), so every access is treated as valid.
Definition: predicated_tile_access_iterator.h:1196
CUTLASS_HOST_DEVICE void add_tile_offset(TensorCoord const &tile_offset)
Definition: predicated_tile_access_iterator.h:793
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id)
Construct a PredicatedTileAccessIterator with zero threadblock offset.
Definition: predicated_tile_access_iterator.h:595
typename platform::remove_const< Element >::type * NonConstPointer
Definition: predicated_tile_access_iterator.h:97
typename platform::remove_const< Element >::type * NonConstPointer
Definition: predicated_tile_access_iterator.h:709
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:246
CUTLASS_HOST_DEVICE bool valid()
Returns whether access is valid or not.
Definition: predicated_tile_access_iterator.h:670
CUTLASS_HOST_DEVICE void enable_mask()
Enables all predicates (sets every predicate bit), so every access is treated as valid.
Definition: predicated_tile_access_iterator.h:1016
typename UnderlyingIterator::Mask Mask
Predicate vector stores mask to guard accesses.
Definition: predicated_tile_access_iterator.h:1080
CUTLASS_HOST_DEVICE void set_iteration_index(int index)
Overrides the internal iteration index.
Definition: predicated_tile_access_iterator.h:964
typename UnderlyingIterator::Mask Mask
Predicate vector stores mask to guard accesses.
Definition: predicated_tile_access_iterator.h:900
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Construct the Params object given a column-major interleaved tensor's layout.
Definition: predicated_tile_access_iterator.h:916
CUTLASS_HOST_DEVICE bool valid()
Returns whether access is valid or not.
Definition: predicated_tile_access_iterator.h:480
Defines a structure containing strides and a pointer to tensor data.
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator operator++(int)
Definition: predicated_tile_access_iterator.h:822
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Construct the Params object given a pitch-linear tensor's layout.
Definition: predicated_tile_access_iterator.h:147
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
CUTLASS_HOST_DEVICE void clear_mask()
Clears the predicate set efficiently.
Definition: predicated_tile_access_iterator.h:830
Array< uint32_t, kPredicateWordCount > Mask
Predicate vector stores mask to guard accesses.
Definition: predicated_tile_access_iterator.h:119
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: predicated_tile_access_iterator.h:347
Template defining a shape used by pitch-linear operators.
Definition: pitch_linear.h:43
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
CUTLASS_HOST_DEVICE bool valid()
Returns whether access is valid or not.
Definition: predicated_tile_access_iterator.h:846
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator operator++(int)
Definition: predicated_tile_access_iterator.h:646
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator & operator++()
Increment and return an instance to self.
Definition: predicated_tile_access_iterator.h:393
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id)
Construct a PredicatedTileAccessIterator with zero threadblock offset.
Definition: predicated_tile_access_iterator.h:321
Defines container classes and iterators for managing a statically sized vector of boolean predicates...
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:59
CUTLASS_HOST_DEVICE half_t & operator++(half_t &lhs)
Definition: half.h:694
CUTLASS_HOST_DEVICE void clear_mask()
Clears the predicate set efficiently.
Definition: predicated_tile_access_iterator.h:1192
CUTLASS_HOST_DEVICE void set_mask(Mask const &mask)
Sets the predicate mask, overriding value stored in predicate iterator.
Definition: predicated_tile_access_iterator.h:662
int64_t LongIndex
Long index type used for offsets.
Definition: pitch_linear.h:175
CUTLASS_DEVICE void add_tile_offset(TensorCoord const &tile_offset)
Advances an iterator along logical dimensions of matrix in units of whole tiles.
Definition: predicated_tile_access_iterator.h:353
CUTLASS_HOST_DEVICE void add_tile_offset(TensorCoord const &tile_offset)
Definition: predicated_tile_access_iterator.h:617
Defines a Shape template for matrix tiles.
Defines the size of an element in bits.
Definition: numeric_types.h:42
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id)
Construct a PredicatedTileAccessIterator with zero threadblock offset.
Definition: predicated_tile_access_iterator.h:953
typename UnderlyingIterator::Mask Mask
Predicate vector stores mask to guard accesses.
Definition: predicated_tile_access_iterator.h:718
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id)
Construct a PredicatedTileAccessIterator with zero threadblock offset.
Definition: predicated_tile_access_iterator.h:771
CUTLASS_HOST_DEVICE void get_mask(Mask &mask)
Gets the mask.
Definition: predicated_tile_access_iterator.h:666
CUTLASS_HOST_DEVICE void get_mask(Mask &mask)
Gets the mask.
Definition: predicated_tile_access_iterator.h:471
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id, TensorCoord const &threadblock_offset)
Definition: predicated_tile_access_iterator.h:576
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator & operator++()
Definition: predicated_tile_access_iterator.h:634
CUTLASS_HOST_DEVICE constexpr const T & min(const T &a, const T &b)
std::min
Definition: platform.h:183
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator operator++(int)
Increment and return an instance to self.
Definition: predicated_tile_access_iterator.h:434
#define static_assert(__e, __m)
Definition: platform.h:153
int32_t Index
Index type used for coordinates.
Definition: pitch_linear.h:172
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id, TensorCoord const &threadblock_offset)
Definition: predicated_tile_access_iterator.h:1112
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Construct the Params object given a pitch-linear tensor's layout.
Definition: predicated_tile_access_iterator.h:1096
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: predicated_tile_access_iterator.h:786
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id, TensorCoord const &threadblock_offset)
Definition: predicated_tile_access_iterator.h:263
CUTLASS_HOST_DEVICE bool valid()
Returns whether access is valid or not.
Definition: predicated_tile_access_iterator.h:1028
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
CUTLASS_HOST_DEVICE void set_iteration_index(int index)
Overrides the internal iteration index.
Definition: predicated_tile_access_iterator.h:782
CUTLASS_HOST_DEVICE void clear_mask()
Clears the predicate set efficiently.
Definition: predicated_tile_access_iterator.h:1012
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator(Params const &params, Pointer pointer, TensorCoord extent, int thread_id, TensorCoord const &threadblock_offset)
Definition: predicated_tile_access_iterator.h:752
CUTLASS_HOST_DEVICE void enable_mask()
Enables the predicate set efficiently.
Definition: predicated_tile_access_iterator.h:658
CUTLASS_HOST_DEVICE void clear_mask()
Clears the predicate set efficiently.
Definition: predicated_tile_access_iterator.h:654
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: predicated_tile_access_iterator.h:1148
CUTLASS_HOST_DEVICE void set_mask(Mask const &mask)
Sets the predicate mask, overriding value stored in predicate iterator.
Definition: predicated_tile_access_iterator.h:461
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: predicated_tile_access_iterator.h:968
Defines layout functions used by TensorRef and derived classes.
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Construct the Params object given a pitch-linear tensor's layout.
Definition: predicated_tile_access_iterator.h:560
CUTLASS_HOST_DEVICE void enable_mask()
Enables the predicate set efficiently.
Definition: predicated_tile_access_iterator.h:834
Defines layout functions used by TensorRef and derived classes for pitch-linear memory.
Definition: layout/matrix.h:343
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:151
typename platform::remove_const< Element >::type * NonConstPointer
Definition: predicated_tile_access_iterator.h:533
CUTLASS_HOST_DEVICE void set_iteration_index(int index)
Overrides the internal iteration index.
Definition: predicated_tile_access_iterator.h:335
CUTLASS_HOST_DEVICE void set_mask(Mask const &mask)
Sets the predicate mask, overriding value stored in predicate iterator.
Definition: predicated_tile_access_iterator.h:1020
CUTLASS_HOST_DEVICE void set_iteration_index(int index)
Overrides the internal iteration index.
Definition: predicated_tile_access_iterator.h:1144
CUTLASS_HOST_DEVICE void enable_mask()
Enables the predicate set efficiently.
Definition: predicated_tile_access_iterator.h:452
CUTLASS_HOST_DEVICE void clear_mask()
Clears the predicate set efficiently.
Definition: predicated_tile_access_iterator.h:442
CUTLASS_HOST_DEVICE PredicatedTileAccessIterator & operator++()
Definition: predicated_tile_access_iterator.h:810
typename UnderlyingIterator::Mask Mask
Predicate vector stores mask to guard accesses.
Definition: predicated_tile_access_iterator.h:540
CUTLASS_HOST_DEVICE bool valid()
Returns whether access is valid or not.
Definition: predicated_tile_access_iterator.h:1208
Basic include for CUTLASS.
Definition: matrix_coord.h:39
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Construct the Params object given a pitch-linear tensor's layout.
Definition: predicated_tile_access_iterator.h:736
Definition: predicated_tile_access_iterator.h:66
CUTLASS_HOST_DEVICE void set_mask(Mask const &mask)
Sets the predicate mask, overriding value stored in predicate iterator.
Definition: predicated_tile_access_iterator.h:838
CUTLASS_HOST_DEVICE void add_tile_offset(TensorCoord const &tile_offset)
Definition: predicated_tile_access_iterator.h:1155
Definition: layout/matrix.h:237