CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
predicate_vector.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #if !defined(__CUDACC_RTC__)
32 #include <assert.h>
33 #endif
34 #include <stdint.h>
35 
36 #include "cutlass/cutlass.h"
37 
39 
40 namespace cutlass {
41 
43 
60 
80 
96 
99 template <
101  int kPredicates_,
103  int kPredicatesPerByte_ = 4,
105  int kPredicateStart_ = 0>
108  static int const kPredicates = kPredicates_;
109 
111  static int const kPredicatesPerByte = kPredicatesPerByte_;
112 
114  static int const kPredicateStart = kPredicateStart_;
115 
116  // Make sure no one tries to put more than 8 bits in a byte :)
117  static_assert(kPredicatesPerByte <= 8, "kPredicatesPerByte must fit within an actual byte");
118  // Make sure the "offsetted" bits fit in one byte.
119  static_assert(kPredicateStart + kPredicatesPerByte <= 8,
120  "The offsetted predicates must fit within an actual byte.");
121 
123  typedef uint32_t Storage;
124 
126  static int const kBytes = (kPredicates + kPredicatesPerByte - 1) / kPredicatesPerByte;
127 
129  static int const kWordCount = (kBytes + sizeof(Storage) - 1) / sizeof(Storage);
130 
131  private:
132  //
133  // Data members
134  //
135 
137  Storage storageData[kWordCount];
138 
139  //
140  // Methods
141  //
142 
144  CUTLASS_HOST_DEVICE void computeStorageOffset(int &word, int &bit, int idx) const {
145  CUTLASS_ASSERT(idx < kPredicates);
146 
147  int byte = (idx / kPredicatesPerByte);
148  int bit_offset = (idx % kPredicatesPerByte);
149 
150  word = byte / sizeof(Storage);
151  int byte_offset = (byte % sizeof(Storage));
152 
153  bit = byte_offset * 8 + bit_offset + kPredicateStart;
154  }
155 
157  CUTLASS_HOST_DEVICE Storage &storage(int word) {
158  CUTLASS_ASSERT(word < kWordCount);
159  return storageData[word];
160  }
161 
163  CUTLASS_HOST_DEVICE Storage const &storage(int word) const {
164  CUTLASS_ASSERT(word < kWordCount);
165  return storageData[word];
166  }
167 
168  public:
169  //
170  // Iterator
171  //
172 
178  class Iterator {
180  PredicateVector &vec_;
181 
183  int bit_;
184 
185  public:
188  Iterator(Iterator const &it) : vec_(it.vec_), bit_(it.bit_) {}
189 
192  Iterator(PredicateVector &vec, int _start = 0) : vec_(vec), bit_(_start) {}
193 
197  ++bit_;
198  return *this;
199  }
200 
203  Iterator &operator+=(int offset) {
204  bit_ += offset;
205  return *this;
206  }
207 
211  --bit_;
212  return *this;
213  }
214 
217  Iterator &operator-=(int offset) {
218  bit_ -= offset;
219  return *this;
220  }
221 
225  Iterator ret(*this);
226  ret.bit_++;
227  return ret;
228  }
229 
233  Iterator ret(*this);
234  ret.bit_--;
235  return ret;
236  }
237 
240  Iterator operator+(int offset) {
241  Iterator ret(*this);
242  ret.bit_ += offset;
243  return ret;
244  }
245 
248  Iterator operator-(int offset) {
249  ConstIterator ret(*this);
250  ret.bit_ -= offset;
251  return ret;
252  }
253 
256  bool operator==(Iterator const &it) const { return bit_ == it.bit_; }
257 
260  bool operator!=(Iterator const &it) const { return bit_ != it.bit_; }
261 
264  bool get() { return vec_.at(bit_); }
265 
268  bool at() const { return vec_.at(bit_); }
269 
272  bool operator*() const { return at(); }
273 
276  void set(bool value = true) { vec_.set(bit_, value); }
277  };
278 
286  PredicateVector const &vec_;
287 
289  int bit_;
290 
291  public:
294  ConstIterator(ConstIterator const &it) : vec_(it.vec_), bit_(it.bit_) {}
295 
298  ConstIterator(PredicateVector const &vec, int _start = 0) : vec_(vec), bit_(_start) {}
299 
303  ++bit_;
304  return *this;
305  }
306 
309  ConstIterator &operator+=(int offset) {
310  bit_ += offset;
311  return *this;
312  }
313 
317  --bit_;
318  return *this;
319  }
320 
323  ConstIterator &operator-=(int offset) {
324  bit_ -= offset;
325  return *this;
326  }
327 
331  ConstIterator ret(*this);
332  ret.bit_++;
333  return ret;
334  }
335 
339  ConstIterator ret(*this);
340  ret.bit_--;
341  return ret;
342  }
343 
346  ConstIterator operator+(int offset) {
347  ConstIterator ret(*this);
348  ret.bit_ += offset;
349  return ret;
350  }
351 
354  ConstIterator operator-(int offset) {
355  ConstIterator ret(*this);
356  ret.bit_ -= offset;
357  return ret;
358  }
359 
362  bool operator==(ConstIterator const &it) const { return bit_ == it.bit_; }
363 
366  bool operator!=(ConstIterator const &it) const { return bit_ != it.bit_; }
367 
370  bool get() { return vec_.at(bit_); }
371 
374  bool at() const { return vec_.at(bit_); }
375 
378  bool operator*() const { return at(); }
379  };
380 
386 
389  TrivialIterator(Iterator const &it) {}
390 
394 
397  TrivialIterator &operator++() { return *this; }
398 
401  TrivialIterator operator++(int) { return *this; }
402 
405  bool operator*() const { return true; }
406  };
407 
408  public:
409  //
410  // Methods
411  //
412 
414  CUTLASS_HOST_DEVICE PredicateVector(bool value = true) { fill(value); }
415 
417  CUTLASS_HOST_DEVICE void fill(bool value = true) {
418  Storage item = (value ? ~Storage(0) : Storage(0));
419 
421  for (int i = 0; i < kWordCount; ++i) {
422  storage(i) = item;
423  }
424  }
425 
429  for (int i = 0; i < kWordCount; ++i) {
430  storage(i) = 0;
431  }
432  }
433 
437  for (int i = 0; i < kWordCount; ++i) {
438  storage(i) = ~Storage(0);
439  }
440  }
441 
443  CUTLASS_HOST_DEVICE bool operator[](int idx) const { return at(idx); }
444 
446  CUTLASS_HOST_DEVICE bool at(int idx) const {
447  int bit, word;
448  computeStorageOffset(word, bit, idx);
449 
450  return ((storage(word) >> bit) & 1);
451  }
452 
454  CUTLASS_HOST_DEVICE void set(int idx, bool value = true) {
455  int bit, word;
456  computeStorageOffset(word, bit, idx);
457 
458  Storage disable_mask = (~(Storage(1) << bit));
459  Storage enable_mask = (Storage(value) << bit);
460 
461  storage(word) = ((storage(word) & disable_mask) | enable_mask);
462  }
463 
467  for (int i = 0; i < kWordCount; ++i) {
468  storage(i) = (storage(i) & predicates.storage(i));
469  }
470  return *this;
471  }
472 
476  for (int i = 0; i < kWordCount; ++i) {
477  storage(i) = (storage(i) | predicates.storage(i));
478  }
479  return *this;
480  }
481 
484  Storage mask(0);
485  for (int byte = 0; byte < sizeof(Storage); ++byte) {
486  Storage byte_mask = (((1 << kPredicatesPerByte) - 1) << kPredicateStart);
487  mask |= (byte_mask << (byte * 8));
488  }
489  uint32_t result = 0;
490  for (int word = 0; word < kWordCount; ++word) {
491  result |= storage(word);
492  }
493  return result == 0;
494  }
495 
497  CUTLASS_DEVICE
498  Iterator begin() { return Iterator(*this); }
499 
501  CUTLASS_DEVICE
502  Iterator end() { return Iterator(*this, kPredicates); }
503 
505  CUTLASS_DEVICE
506  ConstIterator const_begin() const { return ConstIterator(*this); }
507 
509  CUTLASS_DEVICE
510  ConstIterator const_end() const { return ConstIterator(*this, kPredicates); }
511 };
512 
514 
515 } // namespace cutlass
CUTLASS_HOST_DEVICE PredicateVector & operator|=(PredicateVector const &predicates)
Computes the union of two predicate vectors of the same type.
Definition: predicate_vector.h:474
CUTLASS_HOST_DEVICE TrivialIterator & operator++()
Pre-increment.
Definition: predicate_vector.h:397
CUTLASS_HOST_DEVICE ConstIterator(PredicateVector const &vec, int _start=0)
Constructs an iterator from a PredicateVector.
Definition: predicate_vector.h:298
Definition: aligned_buffer.h:35
uint32_t Storage
Storage type of individual elements.
Definition: predicate_vector.h:123
CUTLASS_HOST_DEVICE TrivialIterator(PredicateVector const &_vec)
Constructs an iterator from a PredicateVector.
Definition: predicate_vector.h:393
CUTLASS_HOST_DEVICE bool is_zero() const
Returns true if entire predicate array is zero.
Definition: predicate_vector.h:483
CUTLASS_HOST_DEVICE ConstIterator & operator--()
Pre-decrement.
Definition: predicate_vector.h:316
static int const kBytes
Number of bytes needed.
Definition: predicate_vector.h:126
CUTLASS_DEVICE ConstIterator const_end() const
Returns a ConstIterator positioned one past the last predicate.
Definition: predicate_vector.h:510
CUTLASS_HOST_DEVICE bool at() const
Gets the bit at the pointed to location.
Definition: predicate_vector.h:374
CUTLASS_HOST_DEVICE ConstIterator & operator++()
Pre-increment.
Definition: predicate_vector.h:302
CUTLASS_HOST_DEVICE void enable()
Sets all predicates to true.
Definition: predicate_vector.h:435
CUTLASS_HOST_DEVICE bool operator==(ConstIterator const &it) const
Returns true if iterators point to the same bit.
Definition: predicate_vector.h:362
CUTLASS_HOST_DEVICE bool operator[](int idx) const
Accesses a bit within the predicate vector.
Definition: predicate_vector.h:443
CUTLASS_HOST_DEVICE bool operator==(Iterator const &it) const
Returns true if iterators point to the same bit.
Definition: predicate_vector.h:256
CUTLASS_HOST_DEVICE bool operator!=(Iterator const &it) const
Returns true if the iterators point to different bits.
Definition: predicate_vector.h:260
CUTLASS_HOST_DEVICE bool at(int idx) const
Accesses a bit within the predicate vector.
Definition: predicate_vector.h:446
CUTLASS_HOST_DEVICE Iterator(PredicateVector &vec, int _start=0)
Constructs an iterator from a PredicateVector.
Definition: predicate_vector.h:192
CUTLASS_HOST_DEVICE ConstIterator operator++(int)
Post-increment.
Definition: predicate_vector.h:330
CUTLASS_HOST_DEVICE Iterator operator++(int)
Post-increment.
Definition: predicate_vector.h:224
CUTLASS_HOST_DEVICE TrivialIterator(Iterator const &it)
Constructs a TrivialIterator from an Iterator; the argument is ignored.
Definition: predicate_vector.h:389
CUTLASS_HOST_DEVICE bool operator*() const
Dereferences iterator.
Definition: predicate_vector.h:272
C++ features that may be otherwise unimplemented for CUDA device functions.
Iterator that always returns true.
Definition: predicate_vector.h:382
CUTLASS_HOST_DEVICE TrivialIterator operator++(int)
Post-increment.
Definition: predicate_vector.h:401
CUTLASS_HOST_DEVICE ConstIterator operator+(int offset)
Iterator advances by some amount.
Definition: predicate_vector.h:346
CUTLASS_HOST_DEVICE ConstIterator & operator+=(int offset)
Increment.
Definition: predicate_vector.h:309
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
CUTLASS_DEVICE ConstIterator const_begin() const
Returns a ConstIterator to the first predicate.
Definition: predicate_vector.h:506
CUTLASS_HOST_DEVICE Iterator & operator--()
Pre-decrement.
Definition: predicate_vector.h:210
CUTLASS_HOST_DEVICE Iterator & operator-=(int offset)
Decrement.
Definition: predicate_vector.h:217
CUTLASS_HOST_DEVICE Iterator(Iterator const &it)
Copy constructor.
Definition: predicate_vector.h:188
CUTLASS_HOST_DEVICE ConstIterator operator-(int offset)
Iterator recedes by some amount.
Definition: predicate_vector.h:354
CUTLASS_HOST_DEVICE Iterator operator+(int offset)
Iterator advances by some amount.
Definition: predicate_vector.h:240
CUTLASS_HOST_DEVICE void fill(bool value=true)
Fills all predicates with a given value.
Definition: predicate_vector.h:417
static int const kPredicates
Number of bits stored by the PredicateVector.
Definition: predicate_vector.h:108
CUTLASS_DEVICE Iterator end()
Returns an iterator positioned one past the last predicate.
Definition: predicate_vector.h:502
#define CUTLASS_ASSERT(x)
Definition: cutlass.h:92
CUTLASS_HOST_DEVICE Iterator & operator+=(int offset)
Increment.
Definition: predicate_vector.h:203
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
static int const kPredicatesPerByte
Number of bits stored within each byte of the predicate bit vector.
Definition: predicate_vector.h:111
CUTLASS_HOST_DEVICE PredicateVector & operator&=(PredicateVector const &predicates)
Computes the intersection of two predicate vectors of the same type.
Definition: predicate_vector.h:465
#define static_assert(__e, __m)
Definition: platform.h:153
Statically sized array of bits used to store predicates.
Definition: predicate_vector.h:106
static int const kWordCount
Number of storage elements needed.
Definition: predicate_vector.h:129
CUTLASS_HOST_DEVICE bool operator!=(ConstIterator const &it) const
Returns true if the iterators point to different bits.
Definition: predicate_vector.h:366
CUTLASS_HOST_DEVICE bool at() const
Gets the bit at the pointed to location.
Definition: predicate_vector.h:268
CUTLASS_HOST_DEVICE Iterator & operator++()
Pre-increment.
Definition: predicate_vector.h:196
An iterator implementing the Predicate Iterator Concept, enabling sequential read-only access to predicates.
Definition: predicate_vector.h:284
CUTLASS_HOST_DEVICE void set(int idx, bool value=true)
Set a bit within the predicate vector.
Definition: predicate_vector.h:454
CUTLASS_HOST_DEVICE Iterator operator-(int offset)
Iterator recedes by some amount.
Definition: predicate_vector.h:248
static int const kPredicateStart
First bit within each byte that contains predicates.
Definition: predicate_vector.h:114
CUTLASS_HOST_DEVICE ConstIterator(ConstIterator const &it)
Copy constructor.
Definition: predicate_vector.h:294
CUTLASS_HOST_DEVICE bool operator*() const
Dereferences iterator.
Definition: predicate_vector.h:378
CUTLASS_HOST_DEVICE ConstIterator operator--(int)
Post-decrement.
Definition: predicate_vector.h:338
CUTLASS_HOST_DEVICE void clear()
Clears all predicates.
Definition: predicate_vector.h:427
CUTLASS_HOST_DEVICE PredicateVector(bool value=true)
Initialize the predicate vector.
Definition: predicate_vector.h:414
CUTLASS_DEVICE Iterator begin()
Returns an iterator to the start of the bit vector.
Definition: predicate_vector.h:498
Basic include for CUTLASS.
CUTLASS_HOST_DEVICE bool operator*() const
Dereferences iterator.
Definition: predicate_vector.h:405
An iterator implementing the Predicate Iterator Concept, enabling sequential read and write access to predicates.
Definition: predicate_vector.h:178
CUTLASS_HOST_DEVICE Iterator operator--(int)
Post-decrement.
Definition: predicate_vector.h:232
CUTLASS_HOST_DEVICE ConstIterator & operator-=(int offset)
Decrement.
Definition: predicate_vector.h:323
CUTLASS_HOST_DEVICE TrivialIterator()
Default constructor.
Definition: predicate_vector.h:385