CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
epilogue/threadblock/predicated_tile_iterator.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
33 #pragma once
34 
35 #include "cutlass/cutlass.h"
36 #include "cutlass/numeric_types.h"
37 #include "cutlass/array.h"
38 #include "cutlass/layout/matrix.h"
39 #include "cutlass/matrix_shape.h"
40 #include "cutlass/tensor_ref.h"
41 
44 
45 
47 
48 namespace cutlass {
49 
51 
52 namespace epilogue {
53 namespace threadblock {
54 
56 
61 template <
62  typename ThreadMap_,
63  typename Element_
64 >
66 public:
67  using ThreadMap = ThreadMap_;
68  using Shape = typename ThreadMap::Shape;
69 
70  using Element = Element_;
71 
75 
76  using Index = typename Layout::Index;
77  using LongIndex = typename Layout::LongIndex;
79 
80  static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
81  static int const kThreads = ThreadMap::kThreads;
82  static int const kIterations = ThreadMap::Count::kTile;
83 
84  static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0");
85  static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0");
86  static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0");
87  static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0");
88 
90  using Fragment = Array<
91  Element,
92  ThreadMap::Iterations::kColumn *
93  ThreadMap::Iterations::kRow *
94  ThreadMap::Iterations::kGroup *
95  ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;
96 
99 
100  //
101  // Parameters struct
102  //
103 
104  struct Params {
105 
106  //
107  // Data members
108  //
109 
111 
115 
120 
121  //
122  // Methods
123  //
124 
127 
128  stride = stride_;
129 
130  increment_row = stride * ThreadMap::Delta::kRow;
131 
132  increment_group = stride * ThreadMap::Delta::kGroup
133  - stride * ThreadMap::Delta::kRow * (ThreadMap::Iterations::kRow - 1);
134 
135  increment_cluster = stride * ThreadMap::Delta::kCluster
136  - stride * ThreadMap::Delta::kGroup * (ThreadMap::Iterations::kGroup - 1)
137  - stride * ThreadMap::Delta::kRow * (ThreadMap::Iterations::kRow - 1);
138 
139  advance_row = stride * ThreadMap::Shape::kRow;
140 
141  advance_group = stride * (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
142 
143  advance_cluster =
144  stride *
145  ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;;
146 
147  advance_tile =
148  stride *
149  ThreadMap::Shape::kGroup *
150  ThreadMap::Shape::kRow *
151  ThreadMap::Shape::kCluster *
152  ThreadMap::Shape::kTile;
153 
154  return Status::kSuccess;
155  }
156 
158  Params() {
159  initialize(0);
160  }
161 
163  Params(Layout const &layout) {
164 
165  initialize(layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess);
166  }
167  };
168 
170  struct Mask {
171 
172  static int const kCount = ThreadMap::Iterations::kColumn;
173 
175  bool predicates[kCount];
176 
177  //
178  // Mask
179  //
181  Mask() {
182  enable();
183  }
184 
188  for (int i = 0; i < kCount; ++i) {
189  predicates[i] = false;
190  }
191  }
192 
// Sets every column predicate to true so that all guarded accesses proceed.
194  CUTLASS_DEVICE void enable() {
196  for (int i = 0; i < kCount; ++i) {
197  predicates[i] = true;
198  }
199  }
200  };
201 
202 private:
203 
204  //
205  // Data members
206  //
207 
209  Params params_;
210 
212  uint8_t *byte_pointer_;
213 
215  Mask mask_;
216 
218  Index extent_row_;
219 
221  Index thread_start_row_;
222 
224  int state_[3];
225 
226 private:
227 
228  //
229  // Methods
230  //
231 
232 public:
233 
234  //
235  // Methods
236  //
237 
239  CUTLASS_DEVICE
241  Params const & params,
242  Element *pointer,
243  TensorCoord extent,
244  int thread_idx,
245  TensorCoord threadblock_offset = TensorCoord()
246  ):
247  params_(params) {
248 
249  TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset;
250 
251  extent_row_ = extent.row();
252  thread_start_row_ = thread_offset.row();
253 
254  // Initialize predicates
256  for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) {
257 
258  mask_.predicates[c] = ((thread_offset.column()
259  + ThreadMap::Delta::kColumn * c) < extent.column());
260  }
261 
262  // Initialize pointer
263  byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
264  thread_offset.row() * params_.stride +
265  thread_offset.column() * sizeof(AccessType) / kElementsPerAccess;
266 
267  // Initialize internal state counter
268  state_[0] = state_[1] = state_[2] = 0;
269  }
270 
// Adds a pointer offset given in units of Element; converted to bytes via
// sizeof_bits so the element's bit width is respected.
273  void add_pointer_offset(LongIndex pointer_offset) {
274  byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
275  }
276 
278  CUTLASS_DEVICE
279  void load(Fragment &frag) {
280 
281  uint8_t *byte_pointer = byte_pointer_;
282  AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
283 
285  for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
286 
288  for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
289 
291  for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
292 
293  int frag_row_idx =
294  (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
295 
296  int row_offset = row * ThreadMap::Delta::kRow
297  + group * ThreadMap::Delta::kGroup
298  + cluster * ThreadMap::Delta::kCluster;
299 
300  bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
301 
302  AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer);
303 
305  for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
306 
307  bool guard = row_guard && mask_.predicates[column];
308 
309  if (guard) {
310  frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column] =
311  memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess];
312  }
313  }
314 
315  if (row + 1 < ThreadMap::Iterations::kRow) {
316  byte_pointer += params_.increment_row;
317  }
318  }
319 
320  if (group + 1 < ThreadMap::Iterations::kGroup) {
321  byte_pointer += params_.increment_group;
322  }
323  }
324 
325  if (cluster + 1 < ThreadMap::Iterations::kCluster) {
326  byte_pointer += params_.increment_cluster;
327  }
328  }
329  }
330 
332  CUTLASS_DEVICE
333  void store(Fragment const &frag) {
334  uint8_t *byte_pointer = byte_pointer_;
335  AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
336 
338  for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
339 
341  for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
342 
344  for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
345 
346  int frag_row_idx =
347  (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
348 
349  int row_offset = row * ThreadMap::Delta::kRow
350  + group * ThreadMap::Delta::kGroup
351  + cluster * ThreadMap::Delta::kCluster;
352 
353  bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
354 
355  AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer);
356 
358  for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
359 
360  bool guard = row_guard && mask_.predicates[column];
361 
362  if (guard) {
363 
364  memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] =
365  frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column];
366  }
367  }
368 
369  if (row + 1 < ThreadMap::Iterations::kRow) {
370  byte_pointer += params_.increment_row;
371  }
372  }
373 
374  if (group + 1 < ThreadMap::Iterations::kGroup) {
375  byte_pointer += params_.increment_group;
376  }
377  }
378 
379  if (cluster + 1 < ThreadMap::Iterations::kCluster) {
380  byte_pointer += params_.increment_cluster;
381  }
382  }
383  }
384 
388 
389  ++state_[0];
390  byte_pointer_ += params_.advance_row;
391  thread_start_row_ += ThreadMap::Shape::kRow;
392 
393  if (state_[0] == ThreadMap::Count::kRow) {
394 
395  state_[0] = 0;
396  ++state_[1];
397  byte_pointer_ += params_.advance_group;
398 
399  thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
400  ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
401 
402  if (state_[1] == ThreadMap::Count::kGroup) {
403 
404  state_[1] = 0;
405  ++state_[2];
406  byte_pointer_ += params_.advance_cluster;
407 
408  thread_start_row_ += ThreadMap::Count::kGroup *
409  ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow;
410 
411  if (state_[2] == ThreadMap::Count::kCluster) {
412  state_[2] = 0;
413  byte_pointer_ += params_.advance_tile;
414  }
415  }
416  }
417 
418  return *this;
419  }
420 
// Efficiently disables all accesses guarded by the mask (sets every column
// predicate to false).
422  CUTLASS_DEVICE void clear_mask() {
423  mask_.clear();
424  }
425 
// Re-enables all accesses guarded by the mask (sets every column predicate
// to true).
427  CUTLASS_DEVICE void enable_mask() {
428  mask_.enable();
429  }
430 
432  CUTLASS_DEVICE void get_mask(Mask &mask) {
433  return mask_;
434  }
435 
// Replaces the iterator's predicate mask with the caller-supplied mask.
437  CUTLASS_DEVICE void set_mask(Mask const &mask) {
438  mask_ = mask;
439  }
440 };
441 
447 template <
448  typename ThreadMap_,
449  typename Element_,
450  int InterleavedK
451 >
453 public:
454  using ThreadMap = ThreadMap_;
455 
456  using Element = Element_;
457 
461 
462  using Index = typename Layout::Index;
463  using LongIndex = typename Layout::LongIndex;
465 
466  static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
467  static int const kThreads = ThreadMap::kThreads;
468  static int const kIterations = ThreadMap::Iterations::kCount;
469 
471  using Fragment = Array<Element, ThreadMap::kElementsPerAccess>;
472 
475 
476  //
477  // Parameters struct
478  //
479 
480  struct Params {
481 
482  //
483  // Data members
484  //
485 
487 
490 
491  //
492  // Methods
493  //
494 
497  stride = stride_;
498 
499  advance_row =
500  ThreadMap::Delta::kContiguous * sizeof_bits<Element>::value / 8;
501 
502  advance_column =
503  stride_ - ThreadMap::Iterations::kContiguous * kElementsPerAccess *
504  sizeof_bits<Element>::value * ThreadMap::kWarpSize / 8;
505 
506  return Status::kSuccess;
507  }
508 
510  Params() {
511  initialize(0);
512  }
513 
515  Params(Layout const &layout) {
516 
517  initialize(layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess);
518  }
519  };
520 
522  struct Mask {
523  static int const kCount = (ThreadMap::Iterations::kContiguous < 8)
524  ? 8
525  : ThreadMap::Iterations::kContiguous;
526 
528  bool predicates[kCount];
529 
530  //
531  // Mask
532  //
534  Mask() {
535  enable();
536  }
537 
541  for (int i = 0; i < kCount; ++i) {
542  predicates[i] = false;
543  }
544  }
545 
// Sets every contiguous-dimension predicate to true so that all guarded
// accesses proceed.
547  CUTLASS_DEVICE void enable() {
549  for (int i = 0; i < kCount; ++i) {
550  predicates[i] = true;
551  }
552  }
553  };
554 
555 private:
556 
557  //
558  // Data members
559  //
560 
562  Params params_;
563 
565  uint8_t *byte_pointer_;
566 
568  Mask mask_;
569 
571  Index extent_col_;
572 
575  Index thread_start_col_;
576 
578  int iteration_contiguous_;
579 
580  int iteration_strided_;
581 
582 private:
583 
584  //
585  // Methods
586  //
587 
588 public:
589 
590  //
591  // Methods
592  //
593 
595  CUTLASS_DEVICE
597  Params const & params,
598  Element *pointer,
599  TensorCoord extent,
600  int thread_idx,
601  TensorCoord threadblock_offset
602  ):
603  params_(params) {
604  TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) +
605  TensorCoord(threadblock_offset.contiguous() * InterleavedK,
606  threadblock_offset.strided() / InterleavedK);
607 
608  extent_col_ = extent.strided() / InterleavedK;
609  thread_start_col_ = thread_offset.strided();
610 
611  // Initialize predicates
613  for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
614  mask_.predicates[c] =
615  ((thread_offset.contiguous() + ThreadMap::Delta::kContiguous * c) <
616  (extent.contiguous() * InterleavedK));
617  }
618 
619  // Initialize pointer
620  byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
621  thread_offset.strided() * params_.stride +
622  thread_offset.contiguous() * sizeof(AccessType) / kElementsPerAccess;
623 
624  // Initialize internal state counter
625  iteration_contiguous_ = iteration_strided_ = 0;
626  }
627 
// Adds a pointer offset given in units of Element; converted to bytes via
// sizeof_bits so the element's bit width is respected.
630  void add_pointer_offset(LongIndex pointer_offset) {
631  byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
632  }
633 
635  CUTLASS_DEVICE
636  void load(Fragment &frag) {
637  uint8_t *byte_pointer = byte_pointer_;
638  AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
639  AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer);
640 
641  int col_offset = iteration_strided_ * ThreadMap::Delta::kStrided;
642 
643  bool col_guard = ((thread_start_col_ + col_offset) < extent_col_);
644 
645  bool guard = col_guard && mask_.predicates[iteration_contiguous_];
646 
647  if (guard) {
648  *frag_ptr = *memory_pointer;
649  }
650  }
651 
653  CUTLASS_DEVICE
654  void store(Fragment const &frag) {
655  uint8_t *byte_pointer = byte_pointer_;
656  AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
657  AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer);
658 
659  int col_offset = iteration_strided_ * ThreadMap::Delta::kStrided;
660 
661  bool col_guard = ((thread_start_col_ + col_offset) < extent_col_);
662 
663  bool guard = col_guard && mask_.predicates[iteration_contiguous_];
664 
665  if (guard) {
666  *memory_pointer = *frag_ptr;
667  }
668  }
669 
// Overrides the internal iteration state from a flat iteration index:
// the contiguous index cycles fastest, and the strided index advances once
// every kContiguous steps.
672  void set_iteration_index(int iteration) {
673  iteration_contiguous_ = iteration % ThreadMap::Iterations::kContiguous;
674  iteration_strided_ = iteration / ThreadMap::Iterations::kContiguous;
675  }
676 
680 
681  ++iteration_contiguous_;
682  byte_pointer_ += params_.advance_row;
683 
684  if (iteration_contiguous_ == ThreadMap::Iterations::kContiguous) {
685 
686  iteration_contiguous_ = 0;
687  ++iteration_strided_;
688  byte_pointer_ += params_.advance_column;
689 
690  if (iteration_strided_ == ThreadMap::Iterations::kStrided) {
691  iteration_strided_ = 0;
692  }
693  }
694 
695  return *this;
696  }
697 
// Efficiently disables all accesses guarded by the mask (sets every
// contiguous-dimension predicate to false).
699  CUTLASS_DEVICE void clear_mask() {
700  mask_.clear();
701  }
702 
// Re-enables all accesses guarded by the mask (sets every
// contiguous-dimension predicate to true).
704  CUTLASS_DEVICE void enable_mask() {
705  mask_.enable();
706  }
707 
709  CUTLASS_DEVICE void get_mask(Mask &mask) {
710  return mask_;
711  }
712 
// Replaces the iterator's predicate mask with the caller-supplied mask.
714  CUTLASS_DEVICE void set_mask(Mask const &mask) {
715  mask_ = mask;
716  }
717 };
718 
720 
721 } // namespace threadblock
722 } // namespace epilogue
723 } // namespace cutlass
724 
bool predicates[kCount]
Predicate state.
Definition: epilogue/threadblock/predicated_tile_iterator.h:175
int64_t LongIndex
Long index type used for offsets.
Definition: layout/matrix.h:62
static int const kElementsPerAccess
Definition: epilogue/threadblock/predicated_tile_iterator.h:80
CUTLASS_DEVICE void enable()
Definition: epilogue/threadblock/predicated_tile_iterator.h:194
CUTLASS_HOST_DEVICE Index const & column() const
Returns the column of the coordinate.
Definition: matrix_coord.h:85
Index advance_row
amount to add to move to the next 'row' position
Definition: epilogue/threadblock/predicated_tile_iterator.h:116
CUTLASS_DEVICE void load(Fragment &frag)
Loads a fragment from memory.
Definition: epilogue/threadblock/predicated_tile_iterator.h:279
Element_ Element
Definition: epilogue/threadblock/predicated_tile_iterator.h:70
Definition: aligned_buffer.h:35
Coordinate in pitch-linear space.
Definition: pitch_linear.h:52
Defines a structure containing strides, bounds, and a pointer to tensor data.
AlignedArray< Element, ThreadMap::kElementsPerAccess > AccessType
Memory access size.
Definition: epilogue/threadblock/predicated_tile_iterator.h:98
CUTLASS_HOST_DEVICE Status initialize(Index stride_)
Definition: epilogue/threadblock/predicated_tile_iterator.h:496
CUTLASS_HOST_DEVICE void clear()
Efficiently disables all accesses guarded by mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:539
Templates implementing how threads are mapped to a given tile.
CUTLASS_DEVICE void get_mask(Mask &mask)
Returns the mask via the output parameter.
Definition: epilogue/threadblock/predicated_tile_iterator.h:432
Array< Element, ThreadMap::kElementsPerAccess > Fragment
Fragment object.
Definition: epilogue/threadblock/predicated_tile_iterator.h:471
ThreadMap_ ThreadMap
Definition: epilogue/threadblock/predicated_tile_iterator.h:454
Aligned array type.
Definition: array.h:511
Mask object.
Definition: epilogue/threadblock/predicated_tile_iterator.h:170
CUTLASS_HOST_DEVICE Index const & row() const
Returns the row of the coordinate.
Definition: matrix_coord.h:77
bool predicates[kCount]
Predicate state.
Definition: epilogue/threadblock/predicated_tile_iterator.h:528
typename TensorRef::ConstTensorRef ConstTensorRef
Definition: epilogue/threadblock/predicated_tile_iterator.h:74
CUTLASS_HOST_DEVICE Mask()
Constructs the mask with all accesses enabled.
Definition: epilogue/threadblock/predicated_tile_iterator.h:534
CUTLASS_HOST_DEVICE Stride stride() const
Returns the stride of the layout.
Definition: layout/matrix.h:418
CUTLASS_HOST_DEVICE Stride stride() const
Returns the stride of the layout.
Definition: layout/matrix.h:112
Array< Element, ThreadMap::Iterations::kColumn *ThreadMap::Iterations::kRow *ThreadMap::Iterations::kGroup *ThreadMap::Iterations::kCluster *ThreadMap::kElementsPerAccess > Fragment
Fragment object.
Definition: epilogue/threadblock/predicated_tile_iterator.h:95
Definition: epilogue/threadblock/predicated_tile_iterator.h:480
CUTLASS_DEVICE void store(Fragment const &frag)
Stores a fragment to memory.
Definition: epilogue/threadblock/predicated_tile_iterator.h:333
typename Layout::LongIndex LongIndex
Definition: epilogue/threadblock/predicated_tile_iterator.h:77
TensorRef< typename platform::remove_const< Element >::type const, Layout > ConstTensorRef
TensorRef to constant data.
Definition: tensor_ref.h:179
CUTLASS_DEVICE void load(Fragment &frag)
Loads a fragment from memory.
Definition: epilogue/threadblock/predicated_tile_iterator.h:636
CUTLASS_DEVICE InterleavedPredicatedTileIterator(Params const &params, Element *pointer, TensorCoord extent, int thread_idx, TensorCoord threadblock_offset)
Constructor.
Definition: epilogue/threadblock/predicated_tile_iterator.h:596
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
typename Layout::Index Index
Definition: epilogue/threadblock/predicated_tile_iterator.h:462
int32_t Index
Index type used for coordinates.
Definition: layout/matrix.h:59
Index advance_cluster
amount to add to move to the next 'cluster' position
Definition: epilogue/threadblock/predicated_tile_iterator.h:118
Defines a Shape template for matrix tiles.
CUTLASS_DEVICE void store(Fragment const &frag)
Stores a fragment to memory.
Definition: epilogue/threadblock/predicated_tile_iterator.h:654
Defines the size of an element in bits.
Definition: numeric_types.h:42
Index advance_row
amount to add to move to the next 'row' position
Definition: epilogue/threadblock/predicated_tile_iterator.h:488
CUTLASS_DEVICE void clear_mask()
Efficiently disables all accesses guarded by mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:699
ThreadMap_ ThreadMap
Definition: epilogue/threadblock/predicated_tile_iterator.h:67
Mask object.
Definition: epilogue/threadblock/predicated_tile_iterator.h:522
CUTLASS_HOST_DEVICE Params()
Definition: epilogue/threadblock/predicated_tile_iterator.h:510
CUTLASS_DEVICE PredicatedTileIterator(Params const &params, Element *pointer, TensorCoord extent, int thread_idx, TensorCoord threadblock_offset=TensorCoord())
Constructor.
Definition: epilogue/threadblock/predicated_tile_iterator.h:240
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
CUTLASS_HOST_DEVICE Index const & contiguous() const
Returns the contiguous dimension.
Definition: pitch_linear.h:89
Definition: epilogue/threadblock/predicated_tile_iterator.h:452
typename TensorRef::ConstTensorRef ConstTensorRef
Definition: epilogue/threadblock/predicated_tile_iterator.h:460
#define static_assert(__e, __m)
Definition: platform.h:153
CUTLASS_HOST_DEVICE void set_iteration_index(int iteration)
Overrides the internal iteration index.
Definition: epilogue/threadblock/predicated_tile_iterator.h:672
Index stride
stride in bytes between rows
Definition: epilogue/threadblock/predicated_tile_iterator.h:110
CUTLASS_HOST_DEVICE PredicatedTileIterator & operator++()
Advances to the next position to load or store.
Definition: epilogue/threadblock/predicated_tile_iterator.h:387
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Definition: epilogue/threadblock/predicated_tile_iterator.h:163
Index stride
stride in bytes between columns
Definition: epilogue/threadblock/predicated_tile_iterator.h:486
Index advance_column
amount to add to move to the next &#39;column&#39; position
Definition: epilogue/threadblock/predicated_tile_iterator.h:489
Definition: epilogue/threadblock/predicated_tile_iterator.h:104
static int const kIterations
Definition: epilogue/threadblock/predicated_tile_iterator.h:82
Index advance_tile
amount to add to move to the next 'tile'
Definition: epilogue/threadblock/predicated_tile_iterator.h:119
Metaprogram for determining the mapping of output elements to threads for epilogue tiles...
CUTLASS_HOST_DEVICE InterleavedPredicatedTileIterator & operator++()
Advances to the next position to load or store.
Definition: epilogue/threadblock/predicated_tile_iterator.h:679
CUTLASS_DEVICE void clear_mask()
Efficiently disables all accesses guarded by mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:422
Index increment_group
increment quantity (in bytes) to advance when moving to the next group
Definition: epilogue/threadblock/predicated_tile_iterator.h:113
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
typename Layout::Index Index
Definition: epilogue/threadblock/predicated_tile_iterator.h:76
Definition: epilogue/threadblock/predicated_tile_iterator.h:65
typename Layout::LongIndex LongIndex
Definition: epilogue/threadblock/predicated_tile_iterator.h:463
CUTLASS_DEVICE void set_mask(Mask const &mask)
Definition: epilogue/threadblock/predicated_tile_iterator.h:714
CUTLASS_HOST_DEVICE Params(Layout const &layout)
Definition: epilogue/threadblock/predicated_tile_iterator.h:515
CUTLASS_DEVICE void set_mask(Mask const &mask)
Definition: epilogue/threadblock/predicated_tile_iterator.h:437
CUTLASS_DEVICE void enable()
Definition: epilogue/threadblock/predicated_tile_iterator.h:547
Index advance_group
amount to add to move to the next 'group' position
Definition: epilogue/threadblock/predicated_tile_iterator.h:117
Defines layout functions used by TensorRef and derived classes.
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: epilogue/threadblock/predicated_tile_iterator.h:630
CUTLASS_HOST_DEVICE Status initialize(Index stride_)
Definition: epilogue/threadblock/predicated_tile_iterator.h:126
Operation was successful.
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex pointer_offset)
Adds a pointer offset in units of Element.
Definition: epilogue/threadblock/predicated_tile_iterator.h:273
typename ThreadMap::Shape Shape
Definition: epilogue/threadblock/predicated_tile_iterator.h:68
Definition: layout/matrix.h:343
MatrixCoord TensorCoord
Definition: epilogue/threadblock/predicated_tile_iterator.h:78
CUTLASS_HOST_DEVICE Params()
Definition: epilogue/threadblock/predicated_tile_iterator.h:158
CUTLASS_DEVICE void enable_mask()
Enables all accesses guarded by mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:427
Index increment_row
increment quantity (in bytes) to advance when moving between rows
Definition: epilogue/threadblock/predicated_tile_iterator.h:112
Index increment_cluster
increment quantity (in bytes) to advance when moving to the next cluster
Definition: epilogue/threadblock/predicated_tile_iterator.h:114
CUTLASS_DEVICE void get_mask(Mask &mask)
Returns the mask via the output parameter.
Definition: epilogue/threadblock/predicated_tile_iterator.h:709
Basic include for CUTLASS.
Definition: matrix_coord.h:39
CUTLASS_HOST_DEVICE void clear()
Efficiently disables all accesses guarded by mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:186
CUTLASS_HOST_DEVICE Index const & strided() const
Returns the strided dimension of the coordinate.
Definition: pitch_linear.h:97
Status
Status code returned by CUTLASS operations.
Definition: cutlass.h:39
CUTLASS_DEVICE void enable_mask()
Enables all accesses guarded by mask.
Definition: epilogue/threadblock/predicated_tile_iterator.h:704
CUTLASS_HOST_DEVICE Mask()
Constructs the mask with all accesses enabled.
Definition: epilogue/threadblock/predicated_tile_iterator.h:181
static int const kThreads
Definition: epilogue/threadblock/predicated_tile_iterator.h:81