CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
subbyte_reference.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/numeric_types.h"
31 
32 namespace cutlass {
33 
35 
55 template <
56  typename Element_,
57  typename Storage_ = uint8_t
58 >
61 public:
62 
  // Logical element type addressed by this reference (may be narrower than a byte).
  using Element = Element_;
  // Word type in which sizeof_bits<Storage> / sizeof_bits<Element> elements are packed.
  using Storage = Storage_;
  // Read-only pointer to a packed storage word.
  using StoragePointer = Storage const *;
66 
68  "Size of Element must not be greater than Storage.");
69 
71  "Storage must be divisible by Element");
72 
73 private:
74 
76  int const kElementsPerVector = sizeof_bits<Storage>::value / sizeof_bits<Element>::value;
77 
79  Storage const kMask =
82  ~Storage(0));
83 
84 private:
85 
  // Storage word that contains the referenced element.
  StoragePointer ptr_;

  // Position of the referenced element inside *ptr_, in units of Element.
  // The shift in get() is offset_ * sizeof_bits<Element>::value; expected to lie
  // in [0, kElementsPerVector) — TODO confirm all mutators keep it there.
  int offset_;
93 
94 public:
95 
97  ConstSubbyteReference(): ptr_(nullptr), offset_(0) { }
98 
102  Element const *ptr,
103  int64_t offset
104  ):
105  ptr_(reinterpret_cast<StoragePointer>(ptr)),
106  offset_(0) {
107 
108  int64_t offset_in_vectors = offset / kElementsPerVector;
109  int64_t offset_in_elements = offset % kElementsPerVector;
110 
111  ptr_ += offset_in_vectors;
112  offset_ = int(offset_in_elements);
113  }
114 
118  Element *ptr = nullptr
119  ): ConstSubbyteReference(ptr, 0) { }
120 
124  return ptr_;
125  }
126 
129  int element_offset() const {
130  return offset_;
131  }
132 
135  Element get() const {
136  Storage item = Storage((*ptr_ >> (offset_ * sizeof_bits<Element>::value)) & kMask);
137  return reinterpret_cast<Element const &>(item);
138  }
139 
142  operator Element() const {
143  return get();
144  }
145 
149 
150  offset += offset_;
151 
152  int offset_in_vectors = offset / kElementsPerVector;
153  int offset_in_elements = offset % kElementsPerVector;
154 
155  ptr_ += offset_in_vectors;
156  offset_ = offset_in_elements;
157 
158  return *this;
159  }
160 
163  ConstSubbyteReference &operator+=(long long offset) {
164 
165  offset += offset_;
166 
167  long long offset_in_vectors = offset / kElementsPerVector;
168  int offset_in_elements = int(offset % kElementsPerVector);
169 
170  ptr_ += offset_in_vectors;
171  offset_ = offset_in_elements;
172 
173  return *this;
174  }
175 
179 
180  int offset_in_vectors = offset / kElementsPerVector;
181  int offset_in_elements = offset % kElementsPerVector;
182 
183  ptr_ -= offset_in_vectors;
184  offset_ -= offset_in_elements;
185 
186  if (offset_ < 0) {
187  offset_ += kElementsPerVector;
188  --ptr_;
189  }
190 
191  return *this;
192  }
193 
196  ConstSubbyteReference &operator-=(long long offset) {
197 
198  long long offset_in_vectors = offset / kElementsPerVector;
199  int offset_in_elements = int(offset % kElementsPerVector);
200 
201  ptr_ -= offset_in_vectors;
202  offset_ -= offset_in_elements;
203 
204  if (offset_ < 0) {
205  offset_ += kElementsPerVector;
206  --ptr_;
207  }
208 
209  return *this;
210  }
211 
214  ConstSubbyteReference operator+(int offset) const {
215 
216  ConstSubbyteReference ref(ptr_, offset_);
217  ref += offset;
218 
219  return ref;
220  }
221 
224  ConstSubbyteReference operator+(long long offset) const {
225 
226  ConstSubbyteReference ref(ptr_, offset_);
227  ref += offset;
228 
229  return ref;
230  }
231 
234  ConstSubbyteReference operator-(int offset) const {
235 
236  ConstSubbyteReference ref(ptr_, offset_);
237  ref -= offset;
238 
239  return ref;
240  }
241 
244  ConstSubbyteReference operator-=(long long offset) const {
245 
246  ConstSubbyteReference ref(ptr_, offset_);
247  ref -= offset;
248 
249  return ref;
250  }
251 
254  ptrdiff_t operator-(ConstSubbyteReference ref) const {
255  return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_);
256  }
257 
260  explicit operator int() const {
261  return int(get());
262  }
263 
266  explicit operator int64_t() const {
267  return int64_t(get());
268  }
269 
272  explicit operator uint64_t() const {
273  return uint64_t(get());
274  }
275 
278  explicit operator float() const {
279  return float(get());
280  }
281 
284  explicit operator double() const {
285  return double(get());
286  }
287 };
288 
289 template <
290  typename Element_,
291  typename Storage_ = uint8_t
292 >
295 public:
296 
  // Logical element type addressed by this reference (may be narrower than a byte).
  using Element = Element_;
  // Word type in which sizeof_bits<Storage> / sizeof_bits<Element> elements are packed.
  using Storage = Storage_;
300 
302  "Size of Element must not be greater than Storage.");
303 
305  "Storage must be divisible by Element");
306 
307 private:
308 
310  int const kElementsPerVector = sizeof_bits<Storage>::value / sizeof_bits<Element>::value;
311 
313  Storage const kMask =
316  ~Storage(0));
317 
318 private:
319 
  // Mutable storage word that contains the referenced element.
  StoragePointer ptr_;

  // Position of the referenced element inside *ptr_, in units of Element.
  // get()/set() shift by offset_ * sizeof_bits<Element>::value; expected to lie
  // in [0, kElementsPerVector) — TODO confirm all mutators keep it there.
  int offset_;
327 
328 public:
329 
331  SubbyteReference(): ptr_(nullptr), offset_(0) { }
332 
336  Element *ptr,
337  int64_t offset
338  ):
339  ptr_(reinterpret_cast<StoragePointer>(ptr)),
340  offset_(0) {
341 
342  int64_t offset_in_vectors = offset / kElementsPerVector;
343  int64_t offset_in_elements = offset % kElementsPerVector;
344 
345  ptr_ += offset_in_vectors;
346  offset_ = int(offset_in_elements);
347  }
348 
352  Element *ptr = nullptr
353  ): SubbyteReference(ptr, 0) { }
354 
358  return ptr_;
359  }
360 
363  int element_offset() const {
364  return offset_;
365  }
366 
369  Element get() const {
370  Storage item = Storage((*ptr_ >> (offset_ * sizeof_bits<Element>::value)) & kMask);
371  return reinterpret_cast<Element const &>(item);
372  }
373 
376  SubbyteReference & set(Element const &x) {
377 
378  Storage item = (reinterpret_cast<Storage const &>(x) & kMask);
379 
380  Storage kUpdateMask = Storage(~(kMask << (offset_ * sizeof_bits<Element>::value)));
381  *ptr_ = Storage((*ptr_ & kUpdateMask) | Storage(item << (offset_ * sizeof_bits<Element>::value)));
382 
383  return *this;
384  }
385 
388  operator Element() const {
389  return get();
390  }
391 
395  return set(x);
396  }
397 
401  return set(x.get());
402  }
403 
408  return set(x.get());
409  }
410 
414 
415  offset += offset_;
416 
417  int offset_in_vectors = offset / kElementsPerVector;
418  int offset_in_elements = offset % kElementsPerVector;
419 
420  ptr_ += offset_in_vectors;
421  offset_ = offset_in_elements;
422 
423  return *this;
424  }
425 
428  SubbyteReference &operator+=(long long offset) {
429 
430  offset += offset_;
431 
432  long long offset_in_vectors = offset / kElementsPerVector;
433  int offset_in_elements = int(offset % kElementsPerVector);
434 
435  ptr_ += offset_in_vectors;
436  offset_ = offset_in_elements;
437 
438  return *this;
439  }
440 
444 
445  int offset_in_vectors = offset / kElementsPerVector;
446  int offset_in_elements = offset % kElementsPerVector;
447 
448  ptr_ -= offset_in_vectors;
449  offset_ -= offset_in_elements;
450 
451  if (offset_ < 0) {
452  offset_ += kElementsPerVector;
453  --ptr_;
454  }
455 
456  return *this;
457  }
458 
461  SubbyteReference &operator-=(long long offset) {
462 
463  long long offset_in_vectors = offset / kElementsPerVector;
464  int offset_in_elements = int(offset % kElementsPerVector);
465 
466  ptr_ -= offset_in_vectors;
467  offset_ -= offset_in_elements;
468 
469  if (offset_ < 0) {
470  offset_ += kElementsPerVector;
471  --ptr_;
472  }
473 
474  return *this;
475  }
476 
479  SubbyteReference operator+(int offset) const {
480 
481  SubbyteReference ref(ptr_, offset_);
482  ref += offset;
483 
484  return ref;
485  }
486 
489  SubbyteReference operator+(long long offset) const {
490 
491  SubbyteReference ref(ptr_, offset_);
492  ref += offset;
493 
494  return ref;
495  }
496 
499  SubbyteReference operator-(int offset) const {
500 
501  SubbyteReference ref(ptr_, offset_);
502  ref -= offset;
503 
504  return ref;
505  }
506 
509  SubbyteReference operator-=(long long offset) const {
510 
511  SubbyteReference ref(ptr_, offset_);
512  ref -= offset;
513 
514  return ref;
515  }
516 
519  ptrdiff_t operator-(SubbyteReference ref) const {
520  return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_);
521  }
522 
525  explicit operator int() const {
526  return int(get());
527  }
528 
531  explicit operator int64_t() const {
532  return int64_t(get());
533  }
534 
537  explicit operator uint64_t() const {
538  return uint64_t(get());
539  }
540 
543  explicit operator float() const {
544  return float(get());
545  }
546 
549  explicit operator double() const {
550  return double(get());
551  }
552 };
553 
555 
556 template <typename Element, bool subbyte = (sizeof_bits<Element>::value < 8)>
558 
559 template <typename Element>
560 struct ReferenceFactory<Element, false> {
562  static Element &get(Element *ptr, int64_t offset) {
563  return ptr[offset];
564  }
565 
567  static Element const &get(Element const *ptr, int64_t offset) {
568  return ptr[offset];
569  }
570 };
571 
572 template <typename Element>
573 struct ReferenceFactory<Element, true> {
575  static SubbyteReference<Element> get(Element *ptr, int64_t offset) {
576  return SubbyteReference<Element>(ptr, offset);
577  }
578 
581  int64_t offset) {
582  return ConstSubbyteReference<Element>(ptr, offset);
583  }
584 };
585 
587 
588 } // namespace cutlass
Definition: subbyte_reference.h:60
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE int element_offset() const
Gets element offset within storage vector.
Definition: subbyte_reference.h:363
CUTLASS_HOST_DEVICE SubbyteReference & operator=(Element const &x)
Stores an element to memory.
Definition: subbyte_reference.h:394
CUTLASS_HOST_DEVICE SubbyteReference()
Definition: subbyte_reference.h:331
CUTLASS_HOST_DEVICE SubbyteReference operator-=(long long offset) const
Returns a reference to an element with a given offset from the current reference. ...
Definition: subbyte_reference.h:509
CUTLASS_HOST_DEVICE SubbyteReference(Element *ptr, int64_t offset)
Constructor.
Definition: subbyte_reference.h:335
CUTLASS_HOST_DEVICE SubbyteReference operator+(int offset) const
Returns a reference to an element with a given offset from the current reference. ...
Definition: subbyte_reference.h:479
CUTLASS_HOST_DEVICE ConstSubbyteReference operator-=(long long offset) const
Returns a reference to an element with a given offset from the current reference. ...
Definition: subbyte_reference.h:244
Storage_ Storage
Definition: subbyte_reference.h:64
CUTLASS_HOST_DEVICE ConstSubbyteReference & operator+=(long long offset)
Adds an offset in units of elements to the reference.
Definition: subbyte_reference.h:163
CUTLASS_HOST_DEVICE SubbyteReference operator+(long long offset) const
Returns a reference to an element with a given offset from the current reference. ...
Definition: subbyte_reference.h:489
CUTLASS_HOST_DEVICE ptrdiff_t operator-(ConstSubbyteReference ref) const
Computes the difference in elements between references.
Definition: subbyte_reference.h:254
CUTLASS_HOST_DEVICE ConstSubbyteReference & operator+=(int offset)
Adds an offset in units of elements to the reference.
Definition: subbyte_reference.h:148
CUTLASS_HOST_DEVICE SubbyteReference & operator-=(long long offset)
Adds an offset in units of elements to the reference.
Definition: subbyte_reference.h:461
Defines the size of an element in bits.
Definition: numeric_types.h:42
CUTLASS_HOST_DEVICE Element get() const
Unpacks an element from memory.
Definition: subbyte_reference.h:135
#define nullptr
nullptr
Definition: platform.h:144
Element_ Element
Definition: subbyte_reference.h:297
Definition: subbyte_reference.h:557
CUTLASS_HOST_DEVICE StoragePointer storage_pointer() const
Gets storage pointer.
Definition: subbyte_reference.h:123
CUTLASS_HOST_DEVICE SubbyteReference & operator+=(long long offset)
Adds an offset in units of elements to the reference.
Definition: subbyte_reference.h:428
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
CUTLASS_HOST_DEVICE StoragePointer storage_pointer() const
Gets storage pointer.
Definition: subbyte_reference.h:357
#define static_assert(__e, __m)
Definition: platform.h:153
CUTLASS_HOST_DEVICE ConstSubbyteReference(Element const *ptr, int64_t offset)
Constructor.
Definition: subbyte_reference.h:101
CUTLASS_HOST_DEVICE ConstSubbyteReference operator+(int offset) const
Returns a reference to an element with a given offset from the current reference. ...
Definition: subbyte_reference.h:214
CUTLASS_HOST_DEVICE ConstSubbyteReference & operator-=(int offset)
Adds an offset in units of elements to the reference.
Definition: subbyte_reference.h:178
CUTLASS_HOST_DEVICE ConstSubbyteReference()
Definition: subbyte_reference.h:97
CUTLASS_HOST_DEVICE ConstSubbyteReference operator+(long long offset) const
Returns a reference to an element with a given offset from the current reference. ...
Definition: subbyte_reference.h:224
CUTLASS_HOST_DEVICE SubbyteReference & operator=(SubbyteReference const &x)
Stores an element to memory.
Definition: subbyte_reference.h:400
Storage_ Storage
Definition: subbyte_reference.h:298
CUTLASS_HOST_DEVICE ConstSubbyteReference(Element *ptr=nullptr)
Constructor.
Definition: subbyte_reference.h:117
Definition: subbyte_reference.h:294
CUTLASS_HOST_DEVICE int element_offset() const
Gets element offset within storage vector.
Definition: subbyte_reference.h:129
CUTLASS_HOST_DEVICE Element get() const
Unpacks an element from memory.
Definition: subbyte_reference.h:369
CUTLASS_HOST_DEVICE SubbyteReference & operator=(ConstSubbyteReference< Element, Storage > const &x)
Stores an element to memory.
Definition: subbyte_reference.h:406
Storage * StoragePointer
Definition: subbyte_reference.h:299
Storage const * StoragePointer
Definition: subbyte_reference.h:65
CUTLASS_HOST_DEVICE SubbyteReference(Element *ptr=nullptr)
Constructor.
Definition: subbyte_reference.h:351
CUTLASS_HOST_DEVICE SubbyteReference & operator-=(int offset)
Adds an offset in units of elements to the reference.
Definition: subbyte_reference.h:443
Element_ Element
Definition: subbyte_reference.h:63
CUTLASS_HOST_DEVICE ConstSubbyteReference & operator-=(long long offset)
Adds an offset in units of elements to the reference.
Definition: subbyte_reference.h:196
CUTLASS_HOST_DEVICE SubbyteReference operator-(int offset) const
Returns a reference to an element with a given offset from the current reference. ...
Definition: subbyte_reference.h:499
CUTLASS_HOST_DEVICE ConstSubbyteReference operator-(int offset) const
Returns a reference to an element with a given offset from the current reference. ...
Definition: subbyte_reference.h:234
CUTLASS_HOST_DEVICE ptrdiff_t operator-(SubbyteReference ref) const
Computes the difference in elements between references.
Definition: subbyte_reference.h:519
CUTLASS_HOST_DEVICE SubbyteReference & operator+=(int offset)
Adds an offset in units of elements to the reference.
Definition: subbyte_reference.h:413