CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
predicated_tile_iterator_2dthreadtile.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Templates implementing loading and storing of tiles from pitch-linear rank=2 tensors,
      guarded by predicates to handle out-of-bounds accesses.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/transform/threadblock/predicated_tile_access_iterator_2dthreadtile.h"
#include "cutlass/transform/thread/transpose.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace transform {
namespace threadblock {

////////////////////////////////////////////////////////////////////////////////
/// PredicatedTileIterator2dThreadTile
///
/// Satisfies: ForwardTileIteratorConcept |
///            ReadableContiguousTileIteratorConcept |
///            WriteableContiguousTileIteratorConcept |
///            MaskedTileIteratorConcept
///
/// Regular tile iterator using a precomputed control structure to minimize register liveness
/// and integer arithmetic. Accesses out of bounds are safe so long as clear_mask() is called
/// prior to dereferencing the iterator.
///
/// Example:
///
/// An efficient pipeline structure may be constructed as follows:
///
// template <typename Iterator>
// __global__ void kernel(
//     typename Iterator::Params params,
//     typename Iterator::Element *ptr,
//     TensorCoord extent) {
//
//   typename Iterator::Fragment fragment;
//
//   TensorCoord threadblock_offset(0, 0);
//
//   Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offset);
//
//   fragment = *iter;   // load "residue" tile first
//   ++iter;             // advance to first "steady state" tile and update internal masks
//
//   #pragma unroll
//   for (int i = Remaining - 1; i >= 0; --i) {
//
//     f(fragment);
//
//     if (!i) {
//       iter.clear_mask();   // light-weight operation to clear masks - subsequent loads become NO-OPs.
//     }
//
//     fragment = *iter;   // load tile during "steady state" phase
//     ++iter;             // advance to next tile - lightweight due to steady-state masks
//   }
// }
//
// void host(TensorView<Element, 2, layout::PitchLinear> view) {
//
//   using Iterator = transform::threadblock::PredicatedTileIterator2dThreadTile;
//
//   typename Iterator::Params params(view.layout());
//
//   // Launch configuration (grid, block) chosen by the caller; the extent is
//   // forwarded so the kernel can construct the iterator.
//   kernel<Iterator><<<grid, block>>>(params, view.data(), view.extent());
// }
template <
  typename Shape,
  typename Element,
  typename Layout,
  int AdvanceRank,
  typename ThreadMap,
  bool Transpose = false
>
class PredicatedTileIterator2dThreadTile;

////////////////////////////////////////////////////////////////////////////////
/// Specialization of PredicatedTileIterator2dThreadTile for pitch-linear data.
template <typename Shape_, typename Element_, int AdvanceRank, typename ThreadMap_, bool Transpose_>
class PredicatedTileIterator2dThreadTile<Shape_, Element_, layout::PitchLinear, AdvanceRank, ThreadMap_, Transpose_> {
 public:
  static_assert(
      AdvanceRank == 0 || AdvanceRank == 1,
      "Specialization for pitch-linear iterator may only advance along the "
      "contiguous (rank=0) or strided (rank=1) dimension.");
151 
152  using Shape = Shape_;
153  using Element = Element_;
154  using Layout = layout::PitchLinear;
155  static int const kAdvanceRank = AdvanceRank;
156  using ThreadMap = ThreadMap_;
157 
158  using Index = typename Layout::Index;
159  using LongIndex = typename Layout::LongIndex;
160 
164 
165  using Pointer = Element *;
167 
  /// Type used for internal memory accesses
  struct alignas((ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value / 8)) AccessType {

    Array<Element, ThreadMap::kElementsPerAccess> storage;

    static int const kElements = ThreadMap::kElementsPerAccess;
  };

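  // Worked example of the alignment arithmetic above (illustrative note, not part
  // of the original header): for Element = int8_t, sizeof_bits<int8_t>::value is 8,
  // so a ThreadMap with kElementsPerAccess == 4 yields alignas(4 * 8 / 8) == alignas(4),
  // i.e. each AccessType access is a single 32-bit transaction.
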
  /// Optionally, each loaded fragment may be transposed in units of 4x4 element tiles
  using Transform = thread::Transpose<
      ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount,
      layout::PitchLinearShape<4, 4>, Element>;

  static bool const transpose = Transpose_;

  /// Underlying iterator to compute the addresses
  using TileAccessIterator =
      PredicatedTileAccessIterator2dThreadTile<Shape, Element, Layout, kAdvanceRank,
                                               ThreadMap, AccessType>;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount *
                                               ThreadMap::ThreadAccessShape::kCount>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename TileAccessIterator::Mask;
  /// Parameters object is precomputed state and is host-constructible
  class Params {
   public:
    friend PredicatedTileIterator2dThreadTile;

   private:
    /// Parameters object
    typename TileAccessIterator::Params params_;

   public:
    /// Construct the Params object given a pitch-linear tensor's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const &layout) : params_(layout) { }

    CUTLASS_HOST_DEVICE
    Params() { }
  };
 private:
  /// Internal pointer type permits fast address arithmetic
  using BytePointer = char *;

 private:
  //
  // Data members
  //

  /// Underlying tile access iterator
  TileAccessIterator address_iterator_;

 public:
  /// Constructs a TileIterator from its precomputed state, threadblock offset,
  /// and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator2dThreadTile(
      Params const &params,                    ///< Precomputed parameters object
      Pointer pointer,                         ///< Pointer to start of tensor
      TensorCoord extent,                      ///< Extent of tensor
      int thread_id,                           ///< ID of each participating thread
      TensorCoord const &threadblock_offset)   ///< Initial offset of threadblock
      : address_iterator_(params.params_, pointer, extent, thread_id,
                          threadblock_offset) {}

  /// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator2dThreadTile(
      Params const &params,   ///< Precomputed parameters object
      Pointer pointer,        ///< Pointer to start of tensor
      TensorCoord extent,     ///< Extent of tensor
      int thread_id           ///< ID of each participating thread
      )
      : PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id,
                                           make_Coord(0, 0)) {}

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    address_iterator_.add_pointer_offset(pointer_offset);
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator2dThreadTile &operator++() {
    if (kAdvanceRank)
      address_iterator_.add_tile_offset({0, 1});
    else
      address_iterator_.add_tile_offset({1, 0});

    return *this;
  }

  /// Advances to the next tile in memory; postfix form.
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator2dThreadTile operator++(int) {
    PredicatedTileIterator2dThreadTile self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask() { address_iterator_.clear_mask(); }

  /// Enables the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() { address_iterator_.enable_mask(); }

  /// Sets the predicate mask, overriding value stored in predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const &mask) { address_iterator_.set_mask(mask); }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask &mask) { address_iterator_.get_mask(mask); }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {

    AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      CUTLASS_PRAGMA_UNROLL
      for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
        CUTLASS_PRAGMA_UNROLL
        for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ++ts) {

          int access_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided +
                           s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided;

          address_iterator_.set_iteration_index(access_idx);
          if (address_iterator_.valid()) {
            frag_ptr[access_idx] = *(address_iterator_.get() + pointer_offset);
          }

          ++address_iterator_;
        }
      }
    }

    if (transpose) {
      Transform t;
      t.transform(frag, frag);
    }
  }

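  // Illustrative note (not in the original header): access_idx linearizes the
  // triple loop above with ts fastest, then c, then s. For example, assuming a
  // ThreadMap with Iterations::kContiguous == 2, Iterations::kStrided == 2, and
  // ThreadAccessShape::kStrided == 4, the fragment spans 2 * 2 * 4 = 16 accesses
  // and (s=1, c=1, ts=2) maps to access_idx = 2 + 1*4 + 1*2*4 = 14.
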
  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment &frag) { load_with_pointer_offset(frag, 0); }

  /// Stores a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {

    AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);

    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
      CUTLASS_PRAGMA_UNROLL
      for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
        CUTLASS_PRAGMA_UNROLL
        for (int ts = 0; ts < ThreadMap::ThreadAccessShape::kStrided; ++ts) {

          int access_idx = ts + c * ThreadMap::ThreadAccessShape::kStrided +
                           s * ThreadMap::Iterations::kContiguous * ThreadMap::ThreadAccessShape::kStrided;

          address_iterator_.set_iteration_index(access_idx);
          if (address_iterator_.valid()) {
            *(address_iterator_.get() + pointer_offset) = frag_ptr[access_idx];
          }
          ++address_iterator_;
        }
      }
    }
  }

  /// Stores a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const &frag) { store_with_pointer_offset(frag, 0); }
};

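////////////////////////////////////////////////////////////////////////////////
//
// Usage sketch for the pitch-linear specialization (added for illustration; not
// part of the original header). The tile shape, thread count, and the thread map
// named below are assumptions - any ThreadMap exposing Iterations,
// ThreadAccessShape, and kElementsPerAccess works.
//
//   using Shape     = cutlass::layout::PitchLinearShape<64, 8>;
//   using Element   = int8_t;
//   using ThreadMap = cutlass::transform::PitchLinear2DThreadTileStripminedThreadMap<
//       Shape, 32, cutlass::layout::PitchLinearShape<4, 4>>;   // assumed signature
//
//   using Iterator = cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile<
//       Shape, Element, cutlass::layout::PitchLinear, /*AdvanceRank=*/1, ThreadMap>;
//
//   // Device code, with params/ptr/extent constructed as in the example at the
//   // top of this file:
//   typename Iterator::Fragment frag;
//   Iterator iter(params, ptr, extent, threadIdx.x, {0, 0});
//   iter.load(frag);   // predicated load: out-of-bounds accesses are masked off
//   ++iter;            // advance along the strided dimension (AdvanceRank == 1)
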
////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileIterator2dThreadTile for column-major data.
/// Maps the column-major coordinate system onto the underlying pitch-linear
/// iterator: rows are contiguous, columns are strided.
template <
  typename Shape_,
  typename Element_,
  int AdvanceRank,
  typename ThreadMap_,
  bool Transpose_
>
class PredicatedTileIterator2dThreadTile<Shape_, Element_, layout::ColumnMajor, AdvanceRank, ThreadMap_, Transpose_> {
public:

  static_assert(AdvanceRank == 0 || AdvanceRank == 1,
    "Specialization for column-major iterator may only advance along the "
    "row (rank=0) or column (rank=1) dimension.");
  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::ColumnMajor;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;
  static bool const Transpose = Transpose_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element *;
  using NonConstPointer = typename platform::remove_const<Element>::type *;

  /// Underlying pitch-linear tile iterator type
  using UnderlyingIterator = PredicatedTileIterator2dThreadTile<
    layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
    Element,
    layout::PitchLinear,
    (kAdvanceRank == 0 ? 0 : 1),
    ThreadMap,
    Transpose
  >;

  using AccessType = typename UnderlyingIterator::AccessType;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename UnderlyingIterator::Mask;

  /// Parameters object is precomputed state and is host-constructible
  class Params {
  private:

    friend PredicatedTileIterator2dThreadTile;

    /// Parameters object
    typename UnderlyingIterator::Params params_;

  public:

    CUTLASS_HOST_DEVICE
    Params() { }

    /// Construct the Params object given a column-major matrix's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { }
  };

private:

  //
  // Data members
  //

  /// Underlying pitch-linear tile iterator
  UnderlyingIterator iterator_;

public:

  /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator2dThreadTile(
    Params const &params,                   ///< Precomputed parameters object
    Pointer pointer,                        ///< Pointer to start of tensor
    TensorCoord extent,                     ///< Extent of tensor
    int thread_id,                          ///< ID of each participating thread
    TensorCoord const &threadblock_offset   ///< Initial offset of threadblock
  ):
    iterator_(
      params.params_,
      pointer,
      layout::PitchLinearCoord(extent.row(), extent.column()),
      thread_id,
      layout::PitchLinearCoord(threadblock_offset.row(), threadblock_offset.column())
    ) { }

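  // Note (added for clarity): in column-major storage the row index is the
  // contiguous dimension, so (row, column) maps directly onto the underlying
  // pitch-linear (contiguous, strided) coordinates above without swapping.
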
  /// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator2dThreadTile(
    Params const &params,   ///< Precomputed parameters object
    Pointer pointer,        ///< Pointer to start of tensor
    TensorCoord extent,     ///< Extent of tensor
    int thread_id           ///< ID of each participating thread
  ): PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, make_Coord(0, 0)) { }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    iterator_.add_pointer_offset(pointer_offset);
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator2dThreadTile &operator++() {
    ++iterator_;
    return *this;
  }

  /// Advances to the next tile in memory; postfix form.
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator2dThreadTile operator++(int) {
    PredicatedTileIterator2dThreadTile self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask() {
    iterator_.clear_mask();
  }

  /// Enables the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() {
    iterator_.enable_mask();
  }

  /// Sets the predicate mask, overriding value stored in predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const &mask) {
    iterator_.set_mask(mask);
  }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask &mask) {
    iterator_.get_mask(mask);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
    iterator_.load_with_pointer_offset(frag, pointer_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment &frag) {
    load_with_pointer_offset(frag, 0);
  }

  /// Stores a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
    iterator_.store_with_pointer_offset(frag, pointer_offset);
  }

  /// Stores a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const &frag) {
    store_with_pointer_offset(frag, 0);
  }
};

////////////////////////////////////////////////////////////////////////////////

/// Specialization of PredicatedTileIterator2dThreadTile for row-major data.
/// Maps the row-major coordinate system onto the underlying pitch-linear
/// iterator: columns are contiguous, rows are strided.
template <
  typename Shape_,
  typename Element_,
  int AdvanceRank,
  typename ThreadMap_,
  bool Transpose_
>
class PredicatedTileIterator2dThreadTile<Shape_, Element_, layout::RowMajor, AdvanceRank, ThreadMap_, Transpose_> {
public:

  static_assert(AdvanceRank == 0 || AdvanceRank == 1,
    "Specialization for row-major iterator may only advance along the "
    "row (rank=0) or column (rank=1) dimension.");
  using Shape = Shape_;
  using Element = Element_;
  using Layout = layout::RowMajor;
  static int const kAdvanceRank = AdvanceRank;
  using ThreadMap = ThreadMap_;
  static bool const Transpose = Transpose_;

  using Index = typename Layout::Index;
  using LongIndex = typename Layout::LongIndex;

  using TensorRef = TensorRef<Element, Layout>;
  using TensorView = TensorView<Element, Layout>;
  using TensorCoord = typename Layout::TensorCoord;

  using Pointer = Element *;
  using NonConstPointer = typename platform::remove_const<Element>::type *;

  /// Underlying pitch-linear tile iterator type
  using UnderlyingIterator = PredicatedTileIterator2dThreadTile<
    layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
    Element,
    layout::PitchLinear,
    (kAdvanceRank == 0 ? 1 : 0),
    ThreadMap,
    Transpose
  >;

  using AccessType = typename UnderlyingIterator::AccessType;

  /// Fragment object to be loaded or stored
  using Fragment = cutlass::Array<Element, ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount>;

  /// Predicate vector stores mask to guard accesses
  using Mask = typename UnderlyingIterator::Mask;

  /// Parameters object is precomputed state and is host-constructible
  class Params {
  private:

    friend PredicatedTileIterator2dThreadTile;

    /// Parameters object
    typename UnderlyingIterator::Params params_;

  public:

    CUTLASS_HOST_DEVICE
    Params() { }

    /// Construct the Params object given a row-major matrix's layout
    CUTLASS_HOST_DEVICE
    Params(Layout const &layout): params_(layout::PitchLinear(layout.stride(0))) { }
  };

private:

  //
  // Data members
  //

  /// Underlying pitch-linear tile iterator
  UnderlyingIterator iterator_;

public:

  /// Constructs a TileIterator from its precomputed state, threadblock offset, and thread ID
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator2dThreadTile(
    Params const &params,                   ///< Precomputed parameters object
    Pointer pointer,                        ///< Pointer to start of tensor
    TensorCoord extent,                     ///< Extent of tensor
    int thread_id,                          ///< ID of each participating thread
    TensorCoord const &threadblock_offset   ///< Initial offset of threadblock
  ):
    iterator_(
      params.params_,
      pointer,
      layout::PitchLinearCoord(extent.column(), extent.row()),
      thread_id,
      layout::PitchLinearCoord(threadblock_offset.column(), threadblock_offset.row())
    ) { }

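  // Note (added for clarity): in row-major storage the column index is the
  // contiguous dimension, so (row, column) is swapped into the underlying
  // pitch-linear (contiguous, strided) order above.
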
  /// Construct a PredicatedTileIterator2dThreadTile with zero threadblock offset
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator2dThreadTile(
    Params const &params,   ///< Precomputed parameters object
    Pointer pointer,        ///< Pointer to start of tensor
    TensorCoord extent,     ///< Extent of tensor
    int thread_id           ///< ID of each participating thread
  ): PredicatedTileIterator2dThreadTile(params, pointer, extent, thread_id, make_Coord(0, 0)) { }

  /// Adds a pointer offset in units of Element
  CUTLASS_HOST_DEVICE
  void add_pointer_offset(LongIndex pointer_offset) {
    iterator_.add_pointer_offset(pointer_offset);
  }

  /// Advances to the next tile in memory.
  ///
  /// The first time this method is called, predicates are updated, and the
  /// iterator's internal pointer is reverted to the first "steady state" tile.
  /// Subsequent calls are lightweight and must only update the internal pointer.
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator2dThreadTile &operator++() {
    ++iterator_;
    return *this;
  }

  /// Advances to the next tile in memory; postfix form.
  CUTLASS_HOST_DEVICE
  PredicatedTileIterator2dThreadTile operator++(int) {
    PredicatedTileIterator2dThreadTile self(*this);
    operator++();
    return self;
  }

  /// Clears the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void clear_mask() {
    iterator_.clear_mask();
  }

  /// Enables the predicate set efficiently
  CUTLASS_HOST_DEVICE
  void enable_mask() {
    iterator_.enable_mask();
  }

  /// Sets the predicate mask, overriding value stored in predicate iterator
  CUTLASS_HOST_DEVICE
  void set_mask(Mask const &mask) {
    iterator_.set_mask(mask);
  }

  /// Gets the mask
  CUTLASS_HOST_DEVICE
  void get_mask(Mask &mask) {
    iterator_.get_mask(mask);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load_with_pointer_offset(Fragment &frag, Index pointer_offset) {
    iterator_.load_with_pointer_offset(frag, pointer_offset);
  }

  /// Loads a fragment from memory
  CUTLASS_DEVICE
  void load(Fragment &frag) {
    load_with_pointer_offset(frag, 0);
  }

  /// Stores a fragment to memory
  CUTLASS_DEVICE
  void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) {
    iterator_.store_with_pointer_offset(frag, pointer_offset);
  }

  /// Stores a fragment to memory
  CUTLASS_DEVICE
  void store(Fragment const &frag) {
    store_with_pointer_offset(frag, 0);
  }
};

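// Note (added for illustration): the Transpose flag on each specialization is
// forwarded to the underlying pitch-linear iterator which, when the flag is set,
// transposes every loaded fragment in 4x4 element tiles via thread::Transpose
// (see cutlass/transform/thread/transpose.h). This is typically used with int8
// data that must be rearranged as it is loaded.
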
////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace transform
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////