CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
tensor.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
35 #pragma once
36 #include "assert.h"
37 #include "cutlass/cutlass.h"
38 #include "cutlass/fast_math.h"
39 #include "cutlass/layout/matrix.h"
40 #include "cutlass/coord.h"
41 #include "cutlass/tensor_coord.h"
42 
43 namespace cutlass {
44 namespace layout {
45 
47 //
48 // Defines data layouts of various tensor formats usable by TensorRef and other classes.
49 //
51 
53 class TensorNHWC {
54 public:
56  static int const kRank = 4;
57 
59  static int const kStrideRank = 3;
60 
62  using Index = int32_t;
63 
65  using LongIndex = int64_t;
66 
69 
72 
73 private:
74  //
75  // Data members
76  //
77 
79  Stride stride_;
80 
81 public:
82  //
83  // Methods
84  //
85 
88  TensorNHWC(Stride const &stride = Stride(0)): stride_(stride) { }
89 
92  TensorNHWC(typename Stride::Index c, typename Stride::Index wc, typename Stride::Index hwc): stride_(make_Coord(c, wc, hwc)) { }
93 
96  static TensorNHWC packed(TensorCoord const &extent) {
97  return TensorNHWC(
98  make_Coord(
99  extent.c(),
100  extent.w() * extent.c(),
101  extent.h() * extent.w() * extent.c()
102  )
103  );
104  }
105 
108  LongIndex operator()(TensorCoord const &coord) const {
109  return coord.c() +
110  LongIndex(stride_[0] * coord.w()) +
111  LongIndex(stride_[1] * coord.h()) +
112  LongIndex(stride_[2] * coord.n());
113  }
114 
117  explicit operator RowMajor() {
118  return RowMajor(stride_[0]);
119  }
120 
124 
125  int n = 0, h = 0, w = 0, c = 0;
126 
127  #if defined(__CUDA_ARCH__)
128  int tmp = 0;
129  c = int(index % static_cast<int>(stride_[0]));
130 
131  unsigned int hw_mul, hw_shr, w_mul, w_shr, c_mul, c_shr;
132 
133  find_divisor(hw_mul, hw_shr, stride_[2]);
134  find_divisor(w_mul, w_shr, stride_[1]);
135  find_divisor(c_mul, c_shr, stride_[0]);
136 
137  fast_divmod(n, tmp, index, int(stride_[2]), hw_mul, hw_shr);
138  fast_divmod(h, w, tmp, int(stride_[1]), w_mul, w_shr);
139  fast_divmod(w, tmp, w, int(stride_[0]), c_mul, c_shr);
140  #else
141 
142  n = int(index / (stride_[0] * stride_[1] * stride_[2]));
143  LongIndex residual = index % (stride_[0] * stride_[1] * stride_[2]);
144 
145  h = int(residual / (stride_[0] * stride_[1]));
146  residual = (residual % (stride_[0] * stride_[1]));
147 
148  w = int(residual / stride_[0]);
149  c = int(residual % stride_[0]);
150 
151  #endif
152  return TensorCoord(n, h, w, c);
153  }
154 
157  Stride stride() const {
158  return stride_;
159  }
160 
164  return stride_;
165  }
166 
169  LongIndex capacity(TensorCoord const &extent) const {
170  // it does not make sense if the extent is larger than stride
171  // and we could not rely on the capacity calculation in such cases
172  // we could move this checkers to debug code only
173  if ((extent.c() > stride_[0])
174  || (extent.w() * stride_[0] > stride_[1])
175  || (extent.h() * stride_[1] > stride_[2])) {
176  assert(0);
177  }
178  return extent.n() * stride_[2];
179  }
180 };
181 
182 
184 
186 class TensorNCHW {
187 public:
189  static int const kRank = 4;
190 
192  static int const kStrideRank = 3;
193 
195  using Index = int32_t;
196 
198  using LongIndex = int64_t;
199 
202 
205 
206 private:
207  //
208  // Data members
209  //
210 
212  Stride stride_;
213 
214 public:
215  //
216  // Methods
217  //
218 
221  TensorNCHW(Stride const &stride = Stride(0)): stride_(stride) { }
222 
225  static TensorNCHW packed(TensorCoord const &extent) {
226  return TensorNCHW(
227  make_Coord(
228  extent.w(),
229  extent.w() * extent.h(),
230  extent.h() * extent.w() * extent.c()
231  )
232  );
233  }
234 
237  LongIndex operator()(TensorCoord const &coord) const {
238  return coord.w() +
239  LongIndex(stride_[0] * coord.h()) +
240  LongIndex(stride_[1] * coord.c()) +
241  LongIndex(stride_[2] * coord.n());
242  }
243 
246  Stride stride() const {
247  return stride_;
248  }
249 
253  return stride_;
254  }
255 
258  LongIndex capacity(TensorCoord const &extent) const {
259  return extent.n() * stride_[2];
260  }
261 };
262 
264 
266 template <int Interleave>
268 public:
269 
271  static int const kInterleave = Interleave;
272 
274  static int const kRank = 4;
275 
277  static int const kStrideRank = 3;
278 
280  using Index = int32_t;
281 
283  using LongIndex = int64_t;
284 
287 
290 
291 private:
292  //
293  // Data members
294  //
295 
297  Stride stride_;
298 
299 public:
300  //
301  // Methods
302  //
303 
306  TensorNCxHWx(Stride const &stride = Stride(0)): stride_(stride) { }
307 
310  static TensorNCxHWx packed(TensorCoord const &extent) {
311  return TensorNCxHWx(
312  make_Coord(
313  kInterleave * extent.w(),
314  kInterleave * extent.w() * extent.h(),
315  extent.h() * extent.w() * extent.c()
316  )
317  );
318  }
319 
322  LongIndex operator()(TensorCoord const &coord) const {
323 
324  Index c_minor = (coord.c() % kInterleave);
325  Index c_major = (coord.c() / kInterleave);
326 
327  return c_minor +
328  LongIndex(kInterleave * coord.w()) +
329  LongIndex(stride_[0] * coord.h()) +
330  LongIndex(stride_[1] * c_major) +
331  LongIndex(stride_[2] * coord.n());
332  }
333 
336  Stride stride() const {
337  return stride_;
338  }
339 
343  return stride_;
344  }
345 
348  LongIndex capacity(TensorCoord const &extent) const {
349  return extent.n() * stride_[2];
350  }
351 };
352 
354 
356 template <int Interleave>
358 public:
359 
361  static int const kInterleave = Interleave;
362 
364  static int const kRank = 4;
365 
367  static int const kStrideRank = 3;
368 
370  using Index = int32_t;
371 
373  using LongIndex = int64_t;
374 
377 
380 
381 private:
382  //
383  // Data members
384  //
385 
387  Stride stride_;
388 
389 public:
390  //
391  // Methods
392  //
393 
396  TensorCxRSKx(Stride const &stride = Stride(0)): stride_(stride) { }
397 
400  static TensorCxRSKx packed(TensorCoord const &extent) {
401  return TensorCxRSKx(
402  make_Coord(
403  kInterleave * extent.n(),
404  kInterleave * extent.n() * extent.w(),
405  kInterleave * extent.n() * extent.w() * extent.h()
406  )
407  );
408  }
409 
412  LongIndex operator()(TensorCoord const &coord) const {
413 
414  Index c_minor = (coord.c() % kInterleave);
415  Index c_major = (coord.c() / kInterleave);
416 
417  return c_minor +
418  LongIndex(kInterleave * coord.n()) +
419  LongIndex(stride_[0] * coord.w()) +
420  LongIndex(stride_[1] * coord.h()) +
421  LongIndex(stride_[2] * c_major);
422  }
423 
426  Stride stride() const {
427  return stride_;
428  }
429 
433  return stride_;
434  }
435 
438  LongIndex capacity(TensorCoord const &extent) const {
439  return (extent.c() / kInterleave * stride_[2]);
440  }
441 };
442 
444 
445 } // namespace layout
446 } // namespace cutlass
Coord< kStrideRank > Stride
Stride vector.
Definition: tensor.h:71
CUTLASS_HOST_DEVICE Stride stride() const
Returns the stride of the layout.
Definition: tensor.h:246
Defines a canonical 4D coordinate used by tensor operations.
Definition: tensor_coord.h:38
CUTLASS_HOST_DEVICE TensorCxRSKx(Stride const &stride=Stride(0))
Constructor.
Definition: tensor.h:396
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE void fast_divmod(int &quo, int &rem, int src, int div, unsigned int mul, unsigned int shr)
Definition: fast_math.h:176
CUTLASS_HOST_DEVICE LongIndex capacity(TensorCoord const &extent) const
Compute the number of contiguous elements needed to store a tensor with the given size...
Definition: tensor.h:348
CUTLASS_HOST_DEVICE TensorNCxHWx(Stride const &stride=Stride(0))
Constructor.
Definition: tensor.h:306
static int const kStrideRank
Rank of stride vector.
Definition: tensor.h:59
static CUTLASS_HOST_DEVICE TensorNCxHWx packed(TensorCoord const &extent)
Helper returns a layout to a tightly packed tensor.
Definition: tensor.h:310
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:387
CUTLASS_HOST_DEVICE LongIndex operator()(TensorCoord const &coord) const
Returns the offset of a coordinate in linear memory.
Definition: tensor.h:412
CUTLASS_HOST_DEVICE Stride stride() const
Returns the stride of the layout.
Definition: tensor.h:336
CUTLASS_HOST_DEVICE LongIndex capacity(TensorCoord const &extent) const
Compute the number of contiguous elements needed to store a tensor with the given size...
Definition: tensor.h:169
int32_t Index
Index type used for coordinates.
Definition: tensor.h:62
Tensor4DCoord TensorCoord
Logical coordinate (n, h, w, c)
Definition: tensor.h:68
Mapping function for 4-D NC/xHWx tensors.
Definition: tensor.h:267
int64_t LongIndex
Long index type used for offsets.
Definition: tensor.h:198
CUTLASS_HOST_DEVICE Index const & w() const
Returns the column of the coordinate.
Definition: tensor_coord.h:95
CUTLASS_HOST_DEVICE TensorNHWC(Stride const &stride=Stride(0))
Constructor.
Definition: tensor.h:88
int Index
Index type used to store elements.
Definition: coord.h:55
static CUTLASS_HOST_DEVICE TensorNCHW packed(TensorCoord const &extent)
Helper returns a layout to a tightly packed tensor.
Definition: tensor.h:225
CUTLASS_HOST_DEVICE Stride & stride()
Returns the stride of the layout.
Definition: tensor.h:163
CUTLASS_HOST_DEVICE TensorNCHW(Stride const &stride=Stride(0))
Constructor.
Definition: tensor.h:221
static int const kRank
Logical rank of tensor.
Definition: tensor.h:56
CUTLASS_HOST_DEVICE Stride & stride()
Returns the stride of the layout.
Definition: tensor.h:342
CUTLASS_HOST_DEVICE LongIndex capacity(TensorCoord const &extent) const
Compute the number of contiguous elements needed to store a tensor with the given size...
Definition: tensor.h:258
CUTLASS_HOST_DEVICE Index const & c() const
Returns the channel of the coordinate.
Definition: tensor_coord.h:103
int32_t Index
Index type used for coordinates.
Definition: tensor.h:370
CUTLASS_HOST_DEVICE TensorNHWC(typename Stride::Index c, typename Stride::Index wc, typename Stride::Index hwc)
Constructor.
Definition: tensor.h:92
CUTLASS_HOST_DEVICE Stride & stride()
Returns the stride of the layout.
Definition: tensor.h:432
Defines a canonical coordinate for rank=4 tensors offering named indices.
CUTLASS_HOST_DEVICE Stride stride() const
Returns the stride of the layout.
Definition: tensor.h:426
Mapping function for 4-D CxRSKx tensors.
Definition: tensor.h:357
int64_t LongIndex
Long index type used for offsets.
Definition: tensor.h:65
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
static CUTLASS_HOST_DEVICE TensorNHWC packed(TensorCoord const &extent)
Helper returns a layout to a tightly packed NHWC tensor.
Definition: tensor.h:96
Mapping function for 4-D NCHW tensors.
Definition: tensor.h:186
CUTLASS_HOST_DEVICE Stride stride() const
Returns the stride of the layout.
Definition: tensor.h:157
CUTLASS_HOST_DEVICE LongIndex capacity(TensorCoord const &extent) const
Compute the number of contiguous elements needed to store a tensor with the given size...
Definition: tensor.h:438
static CUTLASS_HOST_DEVICE TensorCxRSKx packed(TensorCoord const &extent)
Helper returns a layout to a tightly packed tensor.
Definition: tensor.h:400
int32_t Index
Index type used for coordinates.
Definition: tensor.h:195
CUTLASS_HOST_DEVICE LongIndex operator()(TensorCoord const &coord) const
Returns the offset of a coordinate in linear memory.
Definition: tensor.h:237
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
CUTLASS_HOST_DEVICE Index const & n() const
Returns the batch of the coordinate.
Definition: tensor_coord.h:79
CUTLASS_HOST_DEVICE void find_divisor(unsigned int &mul, unsigned int &shr, unsigned int denom)
Definition: fast_math.h:159
int64_t LongIndex
Long index type used for offsets.
Definition: tensor.h:373
CUTLASS_HOST_DEVICE Stride & stride()
Returns the stride of the layout.
Definition: tensor.h:252
Defines layout functions used by TensorRef and derived classes.
CUTLASS_HOST_DEVICE LongIndex operator()(TensorCoord const &coord) const
Returns the offset of a coordinate in linear memory.
Definition: tensor.h:322
Math utilities.
int64_t LongIndex
Long index type used for offsets.
Definition: tensor.h:283
CUTLASS_HOST_DEVICE Index const & h() const
Returns the row of the coordinate.
Definition: tensor_coord.h:87
Mapping function for 4-D NHWC tensors.
Definition: tensor.h:53
int32_t Index
Index type used for coordinates.
Definition: tensor.h:280
CUTLASS_HOST_DEVICE TensorCoord inverse(LongIndex index) const
Returns the logical coordinate (n, h, w, c) from a given offset in linear memory. ...
Definition: tensor.h:123
Basic include for CUTLASS.
CUTLASS_HOST_DEVICE LongIndex operator()(TensorCoord const &coord) const
Returns the offset of a coordinate (n, h, w, c) in linear memory.
Definition: tensor.h:108