CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
array_subbyte.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
30 #pragma once
31 
32 #include "cutlass/cutlass.h"
33 #include "cutlass/array.h"
35 
36 namespace cutlass {
37 
39 
41 template <
42  typename T,
43  int N
44 >
45 class Array<T, N, false> {
46 public:
47 
49  "Array<> specialized for sub-byte types assume the actual stored element size is 1 byte");
50 
51  static int const kSizeBits = sizeof_bits<T>::value * N;
52 
54  using Storage = typename platform::conditional<
55  ((kSizeBits % 32) != 0),
56  typename platform::conditional<
57  ((kSizeBits % 16) != 0),
58  uint8_t,
59  uint16_t
60  >::type,
61  uint32_t
62  >::type;
63 
65  using Element = T;
66 
68  static int const kElementsPerStoredItem = (sizeof(Storage) * 8) / sizeof_bits<T>::value;
69 
71  static size_t const kStorageElements = N / kElementsPerStoredItem;
72 
74  static size_t const kElements = N;
75 
77  static Storage const kMask = ((Storage(1) << sizeof_bits<T>::value) - 1);
78 
79  //
80  // C++ standard members with pointer types removed
81  //
82 
83  typedef T value_type;
84  typedef size_t size_type;
85  typedef ptrdiff_t difference_type;
86  typedef value_type *pointer;
87  typedef value_type const *const_pointer;
88 
89  //
90  // References
91  //
92 
94  class reference {
96  Storage *ptr_;
97 
99  int idx_;
100 
101  public:
102 
105  reference(): ptr_(nullptr), idx_(0) { }
106 
109  reference(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
110 
113  reference &operator=(T x) {
114  Storage item = (reinterpret_cast<Storage const &>(x) & kMask);
115 
116  Storage kUpdateMask = Storage(~(kMask << (idx_ * sizeof_bits<T>::value)));
117  *ptr_ = Storage(((*ptr_ & kUpdateMask) | (item << idx_ * sizeof_bits<T>::value)));
118 
119  return *this;
120  }
121 
123  T get() const {
124  Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits<T>::value)) & kMask);
125  return reinterpret_cast<T const &>(item);
126  }
127 
130  operator T() const {
131  return get();
132  }
133 
136  explicit operator int() const {
137  return int(get());
138  }
139 
142  explicit operator float() const {
143  return float(get());
144  }
145  };
146 
148  class const_reference {
149 
151  Storage const *ptr_;
152 
154  int idx_;
155 
156  public:
157 
160  const_reference(): ptr_(nullptr), idx_(0) { }
161 
164  const_reference(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
165 
167  const T get() const {
168  Storage item = (*ptr_ >> (idx_ * sizeof_bits<T>::value)) & kMask;
169  return reinterpret_cast<T const &>(item);
170  }
171 
174  operator T() const {
175  Storage item = Storage(Storage(*ptr_ >> Storage(idx_ * sizeof_bits<T>::value)) & kMask);
176  return reinterpret_cast<T const &>(item);
177  }
178 
181  explicit operator int() const {
182  return int(get());
183  }
184 
187  explicit operator float() const {
188  return float(get());
189  }
190  };
191 
192  //
193  // Iterators
194  //
195 
197  class iterator {
198 
200  Storage *ptr_;
201 
203  int idx_;
204 
205  public:
206 
208  iterator(): ptr_(nullptr), idx_(0) { }
209 
211  iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
212 
214  iterator &operator++() {
215  ++idx_;
216  if (idx_ == kElementsPerStoredItem) {
217  ++ptr_;
218  idx_ = 0;
219  }
220  return *this;
221  }
222 
224  iterator &operator--() {
225  if (!idx_) {
226  --ptr_;
227  idx_ = kElementsPerStoredItem - 1;
228  }
229  else {
230  --idx_;
231  }
232  return *this;
233  }
234 
236  iterator operator++(int) {
237  iterator ret(*this);
238  ++idx_;
239  if (idx_ == kElementsPerStoredItem) {
240  ++ptr_;
241  idx_ = 0;
242  }
243  return ret;
244  }
245 
247  iterator operator--(int) {
248  iterator ret(*this);
249  if (!idx_) {
250  --ptr_;
251  idx_ = kElementsPerStoredItem - 1;
252  }
253  else {
254  --idx_;
255  }
256  return ret;
257  }
258 
260  reference operator*() const {
261  return reference(ptr_, idx_);
262  }
263 
265  bool operator==(iterator const &other) const {
266  return ptr_ == other.ptr_ && idx_ == other.idx_;
267  }
268 
270  bool operator!=(iterator const &other) const {
271  return !(*this == other);
272  }
273  };
274 
276  class const_iterator {
277 
279  Storage const *ptr_;
280 
282  int idx_;
283 
284  public:
285 
287  const_iterator(): ptr_(nullptr), idx_(0) { }
288 
290  const_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
291 
293  iterator &operator++() {
294  ++idx_;
295  if (idx_ == kElementsPerStoredItem) {
296  ++ptr_;
297  idx_ = 0;
298  }
299  return *this;
300  }
301 
303  iterator &operator--() {
304  if (!idx_) {
305  --ptr_;
306  idx_ = kElementsPerStoredItem - 1;
307  }
308  else {
309  --idx_;
310  }
311  return *this;
312  }
313 
315  iterator operator++(int) {
316  iterator ret(*this);
317  ++idx_;
318  if (idx_ == kElementsPerStoredItem) {
319  ++ptr_;
320  idx_ = 0;
321  }
322  return ret;
323  }
324 
326  iterator operator--(int) {
327  iterator ret(*this);
328  if (!idx_) {
329  --ptr_;
330  idx_ = kElementsPerStoredItem - 1;
331  }
332  else {
333  --idx_;
334  }
335  return ret;
336  }
337 
339  const_reference operator*() const {
340  return const_reference(ptr_, idx_);
341  }
342 
344  bool operator==(iterator const &other) const {
345  return ptr_ == other.ptr_ && idx_ == other.idx_;
346  }
347 
349  bool operator!=(iterator const &other) const {
350  return !(*this == other);
351  }
352  };
353 
355  class reverse_iterator {
356 
358  Storage *ptr_;
359 
361  int idx_;
362 
363  public:
364 
366  reverse_iterator(): ptr_(nullptr), idx_(0) { }
367 
369  reverse_iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
370 
371  // TODO
372  };
373 
375  class const_reverse_iterator {
376 
378  Storage const *ptr_;
379 
381  int idx_;
382 
383  public:
384 
386  const_reverse_iterator(): ptr_(nullptr), idx_(0) { }
387 
389  const_reverse_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { }
390 
391  // TODO
392  };
393 
394 private:
395 
397  Storage storage[kStorageElements];
398 
399 public:
400 
402  Array() { }
403 
405  Array(Array const &x) {
407  for (int i = 0; i < int(kStorageElements); ++i) {
408  storage[i] = x.storage[i];
409  }
410  }
411 
414  void clear() {
415 
417  for (int i = 0; i < int(kStorageElements); ++i) {
418  storage[i] = Storage(0);
419  }
420  }
421 
423  reference at(size_type pos) {
424  return reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem);
425  }
426 
428  const_reference at(size_type pos) const {
429  return const_reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem);
430  }
431 
433  reference operator[](size_type pos) {
434  return at(pos);
435  }
436 
438  const_reference operator[](size_type pos) const {
439  return at(pos);
440  }
441 
443  reference front() {
444  return at(0);
445  }
446 
448  const_reference front() const {
449  return at(0);
450  }
451 
453  reference back() {
454  return reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1);
455  }
456 
458  const_reference back() const {
459  return const_reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1);
460  }
461 
464  return reinterpret_cast<pointer>(storage);
465  }
466 
468  const_pointer data() const {
469  return reinterpret_cast<const_pointer>(storage);
470  }
471 
474  return storage;
475  }
476 
478  Storage const * raw_data() const {
479  return storage;
480  }
481 
482 
484  constexpr bool empty() const {
485  return !kElements;
486  }
487 
490  return kElements;
491  }
492 
495  return kElements;
496  }
497 
499  void fill(T const &value) {
500  // TODO
501  }
502 
504  iterator begin() {
505  return iterator(storage);
506  }
507 
509  const_iterator cbegin() const {
510  return const_iterator(storage);
511  }
512 
514  iterator end() {
515  return iterator(storage + kStorageElements);
516  }
517 
519  const_iterator cend() const {
520  return const_iterator(storage + kStorageElements);
521  }
522 
524  reverse_iterator rbegin() {
525  return reverse_iterator(storage + kStorageElements);
526  }
527 
529  const_reverse_iterator crbegin() const {
530  return const_reverse_iterator(storage + kStorageElements);
531  }
532 
534  reverse_iterator rend() {
535  return reverse_iterator(storage);
536  }
537 
539  const_reverse_iterator crend() const {
540  return const_reverse_iterator(storage);
541  }
542 
543  //
544  // Comparison operators
545  //
546 
547 };
548 
550 
551 } // namespace cutlass
552 
CUTLASS_HOST_DEVICE const_reference(Storage const *ptr, int idx=0)
Ctor.
Definition: array_subbyte.h:164
CUTLASS_HOST_DEVICE const_reference at(size_type pos) const
Definition: array_subbyte.h:428
CUTLASS_HOST_DEVICE const_reference back() const
Definition: array_subbyte.h:458
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE reverse_iterator()
Definition: array_subbyte.h:366
#define constexpr
Definition: platform.h:137
CUTLASS_HOST_DEVICE iterator operator--(int)
Definition: array_subbyte.h:326
CUTLASS_HOST_DEVICE Array(Array const &x)
Definition: array_subbyte.h:405
CUTLASS_HOST_DEVICE reference operator[](size_type pos)
Definition: array_subbyte.h:433
CUTLASS_HOST_DEVICE const_reverse_iterator crend() const
Definition: array_subbyte.h:539
CUTLASS_HOST_DEVICE reverse_iterator rbegin()
Definition: array_subbyte.h:524
CUTLASS_HOST_DEVICE bool operator==(iterator const &other) const
Definition: array_subbyte.h:265
size_t size_type
Definition: array_subbyte.h:84
CUTLASS_HOST_DEVICE iterator & operator++()
Definition: array_subbyte.h:214
CUTLASS_HOST_DEVICE const_reverse_iterator crbegin() const
Definition: array_subbyte.h:529
CUTLASS_HOST_DEVICE reverse_iterator rend()
Definition: array_subbyte.h:534
CUTLASS_HOST_DEVICE constexpr bool empty() const
Definition: array_subbyte.h:484
CUTLASS_HOST_DEVICE const_iterator(Storage const *ptr, int idx=0)
Definition: array_subbyte.h:290
ptrdiff_t difference_type
Definition: array_subbyte.h:85
CUTLASS_HOST_DEVICE const_pointer data() const
Definition: array_subbyte.h:468
CUTLASS_HOST_DEVICE reference operator*() const
Definition: array_subbyte.h:260
C++ features that may be otherwise unimplemented for CUDA device functions.
CUTLASS_HOST_DEVICE reference()
Default ctor.
Definition: array_subbyte.h:105
CUTLASS_HOST_DEVICE const_iterator cend() const
Definition: array_subbyte.h:519
CUTLASS_HOST_DEVICE reference(Storage *ptr, int idx=0)
Ctor.
Definition: array_subbyte.h:109
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
CUTLASS_HOST_DEVICE Storage * raw_data()
Definition: array_subbyte.h:473
CUTLASS_HOST_DEVICE const_reference()
Default ctor.
Definition: array_subbyte.h:160
CUTLASS_HOST_DEVICE reference back()
Definition: array_subbyte.h:453
CUTLASS_HOST_DEVICE iterator begin()
Definition: array_subbyte.h:504
Defines the size of an element in bits.
Definition: numeric_types.h:42
CUTLASS_HOST_DEVICE void fill(T const &value)
Definition: array_subbyte.h:499
#define nullptr
nullptr
Definition: platform.h:144
CUTLASS_HOST_DEVICE iterator & operator--()
Definition: array_subbyte.h:303
CUTLASS_HOST_DEVICE Array()
Definition: array_subbyte.h:402
T Element
Element type.
Definition: array_subbyte.h:65
CUTLASS_HOST_DEVICE const_reference operator*() const
Definition: array_subbyte.h:339
CUTLASS_HOST_DEVICE iterator operator--(int)
Definition: array_subbyte.h:247
CUTLASS_HOST_DEVICE reference & operator=(T x)
Assignment.
Definition: array_subbyte.h:113
CUTLASS_HOST_DEVICE iterator end()
Definition: array_subbyte.h:514
CUTLASS_HOST_DEVICE reverse_iterator(Storage *ptr, int idx=0)
Definition: array_subbyte.h:369
CUTLASS_HOST_DEVICE pointer data()
Definition: array_subbyte.h:463
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
std::conditional (true specialization)
Definition: platform.h:325
#define static_assert(__e, __m)
Definition: platform.h:153
CUTLASS_HOST_DEVICE reference front()
Definition: array_subbyte.h:443
CUTLASS_HOST_DEVICE const_reverse_iterator(Storage const *ptr, int idx=0)
Definition: array_subbyte.h:389
CUTLASS_HOST_DEVICE const_iterator cbegin() const
Definition: array_subbyte.h:509
CUTLASS_HOST_DEVICE const_iterator()
Definition: array_subbyte.h:287
CUTLASS_HOST_DEVICE reference at(size_type pos)
Definition: array_subbyte.h:423
CUTLASS_HOST_DEVICE const_reverse_iterator()
Definition: array_subbyte.h:386
CUTLASS_HOST_DEVICE iterator operator++(int)
Definition: array_subbyte.h:315
CUTLASS_HOST_DEVICE Storage const * raw_data() const
Definition: array_subbyte.h:478
value_type const * const_pointer
Definition: array_subbyte.h:87
CUTLASS_HOST_DEVICE constexpr size_type size() const
Definition: array_subbyte.h:489
CUTLASS_HOST_DEVICE constexpr size_type max_size() const
Definition: array_subbyte.h:494
CUTLASS_HOST_DEVICE void clear()
Efficient clear method.
Definition: array_subbyte.h:414
CUTLASS_HOST_DEVICE const_reference front() const
Definition: array_subbyte.h:448
CUTLASS_HOST_DEVICE iterator()
Definition: array_subbyte.h:208
CUTLASS_HOST_DEVICE iterator(Storage *ptr, int idx=0)
Definition: array_subbyte.h:211
typename platform::conditional< ((kSizeBits%32)!=0), typename platform::conditional< ((kSizeBits%16)!=0), uint8_t, uint16_t >::type, uint32_t >::type Storage
Storage type.
Definition: array_subbyte.h:62
CUTLASS_HOST_DEVICE iterator operator++(int)
Definition: array_subbyte.h:236
CUTLASS_HOST_DEVICE iterator & operator--()
Definition: array_subbyte.h:224
CUTLASS_HOST_DEVICE bool operator!=(iterator const &other) const
Definition: array_subbyte.h:349
CUTLASS_HOST_DEVICE bool operator!=(iterator const &other) const
Definition: array_subbyte.h:270
T value_type
Definition: array_subbyte.h:83
value_type * pointer
Definition: array_subbyte.h:86
CUTLASS_HOST_DEVICE bool operator==(iterator const &other) const
Definition: array_subbyte.h:344
CUTLASS_HOST_DEVICE iterator & operator++()
Definition: array_subbyte.h:293
Basic include for CUTLASS.
CUTLASS_HOST_DEVICE const_reference operator[](size_type pos) const
Definition: array_subbyte.h:438