CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
thread/matrix.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/cutlass.h"
32 #include "cutlass/array.h"
33 #include "cutlass/matrix_coord.h"
34 
35 namespace cutlass {
36 namespace thread {
37 
39 
/// Per-thread matrix object storing a packed matrix.
///
/// Derives from Array<Element, Rows * Columns>, so storage is a flat array of
/// Rows * Columns elements; layout() maps a MatrixCoord to an offset into it.
///
// NOTE(review): the parameter list below shows `Element` / `Layout`, but the
// member aliases further down refer to `Element_` / `Layout_`. The trailing
// underscores look to have been dropped in this extract -- confirm against the
// real header.
41 template <
42  typename Element,
43  int Rows,
44  int Columns,
45  typename Layout = layout::RowMajor
46 >
47 class Matrix : public Array<Element, Rows * Columns> {
48 public:
49 
50  // Verify layout refers to a rank=2 matrix.
    // NOTE(review): the opening `static_assert(` line (listing line 51) is
    // missing from this extract; the two lines below are its arguments.
52  Layout::kRank == 2,
53  "Layout type must refer to a rank=2 matrix");
54 
    /// Base type.
56  using Base = Array<Element, Rows * Columns>;
57 
    /// Element type.
59  using Element = Element_;
60 
    /// Number of rows.
62  static int const kRows = Rows;
63 
    /// Number of columns.
65  static int const kColumns = Columns;
66 
    /// Layout type (alias of the layout template argument).
68  using Layout = Layout_;
69 
    /// Reference type to an element.
71  using Reference = Element &;
72 
    /// Logical rank of tensor index space.
74  static int const kRank = 2;
75 
    /// Index type.
77  using Index = typename Layout::Index;
78 
    /// Long index used for pointer offsets.
80  using LongIndex = typename Layout::LongIndex;
81 
    /// Coordinate in logical tensor space.
83  using TensorCoord = typename Layout::TensorCoord;
84 
    /// Stride type.
86  using Stride = typename Layout::Stride;
87 
    // NOTE(review): listing lines 89, 92, 95 and 98 are missing here. Per the
    // page index they declare the TensorRef, ConstTensorRef, TensorView and
    // ConstTensorView aliases that the accessor methods below return.
90 
93 
96 
99 
    /// Diagonal vector; its length is the smaller of kRows and kColumns.
101  using Diagonal = Vector<Element, __NV_STD_MIN(kRows, kColumns)>;
102 
103 private:
104 
105 
106 public:
107 
108  //
109  // Methods
110  //
111 
    /// Returns the size of the object as a coordinate built from
    /// (kRows, kColumns).
114  static MatrixCoord extent() {
115  return make_Coord(kRows, kColumns);
116  }
117 
    /// Returns the layout object, obtained from Layout::packed(extent()).
120  static Layout layout() {
121  return Layout::packed(extent());
122  }
123 
    /// Ctor. The body is empty; nothing is assigned here.
126  Matrix() { }
127 
    /// Ctor taking a diagonal vector.
    // NOTE(review): not implemented -- `diag` is never read, so this currently
    // behaves like the default constructor. It is also not `explicit`, so a
    // Diagonal converts implicitly to a Matrix; confirm that is intended.
130  Matrix(Diagonal const &diag) {
131  // TODO: construct from diagonal (not implemented yet)
132  }
133 
    /// Returns a TensorRef built from this->data() and layout().
    // NOTE(review): the signature line (listing line 136) is missing from this
    // extract; the page index gives it as `CUTLASS_HOST_DEVICE TensorRef ref()`.
137  return TensorRef(this->data(), layout());
138  }
139 
    /// Returns a ConstTensorRef built from this->data() and layout().
    // NOTE(review): signature line (listing line 142) missing; the page index
    // gives it as `CUTLASS_HOST_DEVICE ConstTensorRef const_ref() const`.
143  return ConstTensorRef(this->data(), layout());
144  }
145 
    /// Returns a TensorView built from ref() and extent().
    // NOTE(review): signature line (listing line 148) missing; the page index
    // gives it as `CUTLASS_HOST_DEVICE TensorView view()`.
149  return TensorView(ref(), extent());
150  }
151 
    /// Returns a ConstTensorView built from const_ref() and extent().
    // NOTE(review): signature line (listing line 154) missing; the page index
    // gives it as `CUTLASS_HOST_DEVICE ConstTensorView const_view() const`.
155  return ConstTensorView(const_ref(), extent());
156  }
157 
    /// Returns a reference to the element at a given Coord.
    /// The coordinate is converted to a linear offset via layout().offset()
    /// and forwarded to Base::at().
    // NOTE(review): this method is `const` yet returns the mutable
    // `Reference` (Element &). Whether that compiles, and whether it lets
    // callers write through a const Matrix, depends on what Base::at() returns
    // in a const context -- confirm against Array.
160  Reference at(MatrixCoord const& coord) const {
161  typename Base::size_type offset_(layout().offset(coord));
162  return Base::at(offset_);
163  }
164 
    /// Returns the number of scalar elements needed to store tensor
    /// (Base::size() converted to LongIndex).
167  LongIndex capacity() const {
168  return LongIndex(Base::size());
169  }
170 };
171 
173 
175 template <
176  typename Element,
177  int Rows,
178  typename Layout = layout::ColumnMajor
179 >
181 
183 template <
184  typename Element,
185  int Columns,
186  typename Layout = layout::RowMajor
187 >
189 
191 
192 } // namespace thread
193 } // namespace cutlass
typename Layout::TensorCoord TensorCoord
Coordinate in logical tensor space.
Definition: thread/matrix.h:83
Per-thread matrix object storing a packed matrix.
Definition: thread/matrix.h:47
Definition: aligned_buffer.h:35
typename Layout::Stride Stride
Stride type.
Definition: thread/matrix.h:86
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:387
Array< Element, Rows *Columns > Base
Base type.
Definition: thread/matrix.h:56
CUTLASS_HOST_DEVICE TensorRef ref()
Returns a TensorRef pointing to the first element of the tensor.
Definition: thread/matrix.h:136
Vector< Element, __NV_STD_MIN(kRows, kColumns)> Diagonal
Diagonal vector.
Definition: thread/matrix.h:101
static CUTLASS_HOST_DEVICE MatrixCoord extent()
Returns the size of the object.
Definition: thread/matrix.h:114
static int const kRows
Number of rows.
Definition: thread/matrix.h:62
CUTLASS_HOST_DEVICE LongIndex capacity() const
Returns the number of scalar elements needed to store tensor.
Definition: thread/matrix.h:167
TensorRef< typename platform::remove_const< Element >::type const, Layout > ConstTensorRef
TensorRef to constant data.
Definition: tensor_ref.h:179
typename TensorRef::ConstTensorRef ConstTensorRef
TensorRef to constant matrix object.
Definition: thread/matrix.h:92
Mapping function for column-major matrices.
Definition: layout/matrix.h:142
TensorView< Element, kRank, Layout > TensorView
TensorView to matrix object.
Definition: thread/matrix.h:95
CUTLASS_HOST_DEVICE TensorView view()
Returns a TensorView of the tensor (its TensorRef together with its extent).
Definition: thread/matrix.h:148
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
TensorView< typename platform::remove_const< Element >::type const, Layout > ConstTensorView
TensorView pointing to constant memory.
Definition: tensor_view.h:95
CUTLASS_HOST_DEVICE Reference at(MatrixCoord const &coord) const
Returns a reference to the element at a given Coord.
Definition: thread/matrix.h:160
TensorRef< Element, kRank, Layout > TensorRef
TensorRef to matrix object.
Definition: thread/matrix.h:89
typename Layout::Index Index
Index type.
Definition: thread/matrix.h:77
typename TensorView::ConstTensorView ConstTensorView
TensorView to constant matrix object.
Definition: thread/matrix.h:98
typename Layout::LongIndex LongIndex
Long index used for pointer offsets.
Definition: thread/matrix.h:80
Element_ Element
Element type.
Definition: thread/matrix.h:59
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
#define static_assert(__e, __m)
Definition: platform.h:153
Mapping function for row-major matrices.
Definition: layout/matrix.h:50
CUTLASS_HOST_DEVICE Matrix()
Ctor.
Definition: thread/matrix.h:126
Defines a canonical coordinate for rank=2 matrices offering named indices.
CUTLASS_HOST_DEVICE Matrix(Diagonal const &diag)
Ctor.
Definition: thread/matrix.h:130
Element & Reference
Reference type to an element.
Definition: thread/matrix.h:71
CUTLASS_HOST_DEVICE ConstTensorRef const_ref() const
Returns a ConstTensorRef pointing to the first element of the tensor.
Definition: thread/matrix.h:142
static CUTLASS_HOST_DEVICE Layout layout()
Returns the layout object.
Definition: thread/matrix.h:120
Basic include for CUTLASS.
Definition: matrix_coord.h:39
CUTLASS_HOST_DEVICE ConstTensorView const_view() const
Returns a TensorView to const data.
Definition: thread/matrix.h:154
static int const kRank
Logical rank of tensor index space.
Definition: thread/matrix.h:74
static int const kColumns
Number of columns.
Definition: thread/matrix.h:65