CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
tensor_coord.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/cutlass.h"
31 #include "cutlass/coord.h"
32 
33 namespace cutlass {
34 
36 
38 struct Tensor4DCoord : public Coord<4> {
39 
41  using Base = Coord<4>;
42 
44  using Index = typename Base::Index;
45 
47  using LongIndex = typename Base::LongIndex;
48 
50  static int const kN = 0;
51 
53  static int const kH = 1;
54 
56  static int const kW = 2;
57 
59  static int const kC = 3;
60 
61  //
62  // Methods
63  //
64 
68 
71  Tensor4DCoord(Coord<4> const &coord): Base(coord) { }
72 
76 
79  Index const & n() const { return this->at(kN); }
80 
83  Index & n() { return this->at(kN); }
84 
87  Index const & h() const { return this->at(kH); }
88 
91  Index & h() { return this->at(kH); }
92 
95  Index const & w() const { return this->at(kW); }
96 
99  Index & w() { return this->at(kW); }
100 
103  Index const & c() const { return this->at(kC); }
104 
107  Index & c() { return this->at(kC); }
108 
109  //
110  // Coord operators
111  //
112 
115  Tensor4DCoord operator+(Base const& b) const {
116  return Tensor4DCoord(Base::operator+(b));
117  }
118 
121  Tensor4DCoord operator-(Base const& b) const {
122  return Tensor4DCoord(Base::operator-(b));
123  }
124 
127  Tensor4DCoord operator*(Base const& b) const {
128  return Tensor4DCoord(Base::operator*(b));
129  }
130 
133  Tensor4DCoord operator/(Base const& b) const {
134  return Tensor4DCoord(Base::operator/(b));
135  }
136 
140  Base::operator+=(b);
141  return *this;
142  }
143 
147  Base::operator-=(b);
148  return *this;
149  }
150 
154  Base::operator*=(b);
155  return *this;
156  }
157 
161  Base::operator/=(b);
162  return *this;
163  }
164 };
165 
167 
168 } // namespace cutlass
CUTLASS_HOST_DEVICE Index & n()
Returns the batch of the coordinate.
Definition: tensor_coord.h:83
Defines a canonical 4D coordinate used by tensor operations.
Definition: tensor_coord.h:38
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE Tensor4DCoord operator/(Base const &b) const
Element-wise division.
Definition: tensor_coord.h:133
CUTLASS_HOST_DEVICE Coord & operator*=(Coord const &b)
In-place multiplication.
Definition: coord.h:222
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:387
CUTLASS_HOST_DEVICE Coord & operator-=(Coord const &b)
In-place subtraction.
Definition: coord.h:213
CUTLASS_HOST_DEVICE Index & w()
Returns the column of the coordinate.
Definition: tensor_coord.h:99
CUTLASS_HOST_DEVICE Index const & w() const
Returns the column of the coordinate.
Definition: tensor_coord.h:95
int Index
Index type used to store elements.
Definition: coord.h:55
CUTLASS_HOST_DEVICE Tensor4DCoord operator*(Base const &b) const
Element-wise multiplication.
Definition: tensor_coord.h:127
CUTLASS_HOST_DEVICE Tensor4DCoord()
Default ctor.
Definition: tensor_coord.h:67
CUTLASS_HOST_DEVICE Index const & c() const
Returns the channel of the coordinate.
Definition: tensor_coord.h:103
static int const kC
Channels dimension.
Definition: tensor_coord.h:59
CUTLASS_HOST_DEVICE Index & c()
Returns the channel of the coordinate.
Definition: tensor_coord.h:107
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
typename Base::LongIndex LongIndex
LongIndex type.
Definition: tensor_coord.h:47
CUTLASS_HOST_DEVICE Coord & operator/=(Coord const &b)
In-place division.
Definition: coord.h:231
CUTLASS_HOST_DEVICE Tensor4DCoord & operator/=(Base const &b)
In-place division.
Definition: tensor_coord.h:160
CUTLASS_HOST_DEVICE Tensor4DCoord & operator-=(Base const &b)
In-place subtraction.
Definition: tensor_coord.h:146
Statically-sized array specifying Coords within a tensor.
Definition: coord.h:43
CUTLASS_HOST_DEVICE Index const & n() const
Returns the batch of the coordinate.
Definition: tensor_coord.h:79
CUTLASS_HOST_DEVICE Tensor4DCoord(Coord< 4 > const &coord)
Constructs from Coord<4>
Definition: tensor_coord.h:71
CUTLASS_HOST_DEVICE Index & h()
Returns the row of the coordinate.
Definition: tensor_coord.h:91
static int const kN
Batch dimension.
Definition: tensor_coord.h:50
CUTLASS_HOST_DEVICE Tensor4DCoord(Index n, Index h, Index w, Index c)
Helper to construct from N, H, W, and C.
Definition: tensor_coord.h:75
static int const kW
Width dimension.
Definition: tensor_coord.h:56
CUTLASS_HOST_DEVICE Tensor4DCoord & operator+=(Base const &b)
In-place addition.
Definition: tensor_coord.h:139
CUTLASS_HOST_DEVICE Coord & operator+=(Coord const &b)
In-place addition.
Definition: coord.h:204
CUTLASS_HOST_DEVICE Index const & h() const
Returns the row of the coordinate.
Definition: tensor_coord.h:87
int64_t LongIndex
Type used to represent linear offsets.
Definition: coord.h:58
CUTLASS_HOST_DEVICE Tensor4DCoord & operator*=(Base const &b)
In-place multiplication.
Definition: tensor_coord.h:153
CUTLASS_HOST_DEVICE Index & at()
Gets the index of a given Coord element.
Definition: coord.h:255
CUTLASS_HOST_DEVICE Tensor4DCoord operator+(Base const &b) const
Element-wise addition.
Definition: tensor_coord.h:115
Basic include for CUTLASS.
static int const kH
Height dimension.
Definition: tensor_coord.h:53
typename Base::Index Index
Index type.
Definition: tensor_coord.h:44
CUTLASS_HOST_DEVICE Tensor4DCoord operator-(Base const &b) const
Element-wise subtraction.
Definition: tensor_coord.h:121