CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
tensor_view.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
37 #pragma once
38 
39 #if !defined(__CUDACC_RTC__)
40 #include <cmath>
41 #endif
42 
43 #include "cutlass/cutlass.h"
44 #include "cutlass/tensor_ref.h"
45 
46 namespace cutlass {
47 
49 
50 template <
52  typename Element_,
54  typename Layout_
55 >
56 class TensorView : public TensorRef<Element_, Layout_> {
57  public:
58 
61 
63  using Layout = Layout_;
64 
67 
69  using TensorRef = Base;
70 
72  using Element = Element_;
73 
75  using Reference = Element &;
76 
78  static int const kRank = Layout::kRank;
79 
81  using Index = typename Layout::Index;
82 
84  using LongIndex = typename Layout::LongIndex;
85 
87  using TensorCoord = typename Layout::TensorCoord;
88 
90  using Stride = typename Layout::Stride;
91 
96 
99  typename platform::remove_const<Element>::type,
101 
105  static_assert(kRank > 0, "Cannot define a zero-rank TensorRef");
106 
107  private:
108 
110  TensorCoord extent_;
111 
112  public:
113 
114  //
115  // Methods
116  //
117 
121 
122  }
123 
127  Element *ptr,
128  Layout const &layout,
129  TensorCoord const &extent
130  ):
131  Base(ptr, layout), extent_(extent) {
132 
133  }
134 
138  TensorRef const &ref,
139  TensorCoord const &extent
140  ):
141  Base(ref), extent_(extent) {
142 
143  }
144 
148  NonConstTensorView const &view
149  ):
150  Base(view), extent_(view.extent_) { }
151 
152  /// Updates the pointer and layout object, then resizes the view to `size`.
153  /// `ptr` and `layout` are forwarded to Base::reset(); `size` becomes the new extent.
154  void reset(Element* ptr, Layout const &layout, TensorCoord size) {
155  Base::reset(ptr, layout);
156  this->resize(size);  // was resize(extent_): re-stored the old extent and ignored `size`
157  }
158 
162  this->extent_ = extent;
163  }
164 
165  /// Returns the extent of the view (the size along each logical dimension).
167  TensorCoord const& extent() const { return this->extent_; }
168 
169  /// Returns the extent along the logical dimension `dim`.
171  Index extent(int dim) const { return this->extent().at(dim); }
172 
173  /// Determines whether a location is within a tensor: returns true only when
174  /// every component of `coord` is >= 0 and < extent(dim) for its dimension.
175  bool contains(TensorCoord const& coord) const {
177  for (int axis = 0; axis < kRank; ++axis) {
178  bool const inside = coord[axis] >= 0 && coord[axis] < this->extent(axis);
179  if (!inside) { return false; }  // first out-of-range component decides
180  }
182  return true;
183  }
184 
185  /// Returns a TensorRef built from this view's data pointer and layout;
186  /// the extent is not carried over.
187  TensorRef ref() const { return TensorRef(this->data(), this->layout()); }
190 
194  return ConstTensorRef(this->data(), this->layout());
195  }
196 
200  return ConstTensorView(const_ref(), extent_);
201  }
202 
207  TensorCoord const& location = TensorCoord()) const {
208  TensorView result(ref(), extent.clamp(extent_ - location));  // same extent expression as before
209  result.add_coord_offset(location);  // offset in place: its TensorRef& result cannot be returned as a TensorView
210  return result;
211  }
212 
213  /// Returns the number of scalar elements needed to store tensor,
214  /// as reported by the layout's capacity() for the current extent.
215  size_t capacity() const { return Base::layout().capacity(this->extent_); }
218 
222  TensorCoord const& b
223  ) const {
224 
225  TensorView result(*this);
226  result.add_pointer_offset(this->offset(b));
227  return result;
228  }
229 
233  TensorCoord const& b
234  ) {
235 
236  this->add_pointer_offset(this->offset(b));
237  return *this;
238  }
239 
243  TensorCoord const& b
244  ) const {
245  // Copy as a TensorView, as operator+ does: a TensorRef copy drops extent_ and no visible constructor turns it back into a TensorView.
246  TensorView result(*this);
247  result.add_pointer_offset(-this->offset(b));
248  return result;
249  }
250 
254  TensorCoord const& b
255  ) {
256 
257  this->add_pointer_offset(-this->offset(b));
258  return *this;
259  }
260 };
261 
263 
265 template <
266  typename Element,
267  typename Layout
268 >
270  Element *ptr,
271  Layout const &layout,
272  typename Layout::TensorCoord const &extent) {
273 
275 }
276 
278 
279 } // namespace cutlass
CUTLASS_HOST_DEVICE TensorView & operator+=(TensorCoord const &b)
Offsets this TensorView in place by the given logical coordinate and returns a reference to it.
Definition: tensor_view.h:232
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE size_t capacity() const
Returns the number of scalar elements needed to store tensor.
Definition: tensor_view.h:215
Defines a structure containing strides, bounds, and a pointer to tensor data.
T type
Definition: platform.h:351
CUTLASS_HOST_DEVICE Element * data() const
Returns the pointer to referenced data.
Definition: tensor_ref.h:254
CUTLASS_HOST_DEVICE TensorCoord const & extent() const
Returns the extent of the view (the size along each logical dimension).
Definition: tensor_view.h:167
static int const kRank
Logical rank of tensor index space.
Definition: tensor_view.h:78
CUTLASS_HOST_DEVICE void resize(TensorCoord extent)
Changes the size of the view without affecting pointer or layout.
Definition: tensor_view.h:161
CUTLASS_HOST_DEVICE TensorView operator+(TensorCoord const &b) const
Returns a TensorView offset by a given amount.
Definition: tensor_view.h:221
CUTLASS_HOST_DEVICE TensorView & operator-=(TensorCoord const &b)
Offsets this TensorView in place by the negated logical coordinate and returns a reference to it.
Definition: tensor_view.h:253
CUTLASS_HOST_DEVICE TensorView operator-(TensorCoord const &b) const
Returns a TensorView offset backward by the given logical coordinate.
Definition: tensor_view.h:242
Base TensorRef
Underlying TensorRef type.
Definition: tensor_view.h:69
TensorRef< typename platform::remove_const< Element >::type const, Layout > ConstTensorRef
TensorRef to constant data.
Definition: tensor_ref.h:179
CUTLASS_HOST_DEVICE TensorRef & add_coord_offset(TensorCoord const &coord)
Adds an offset to each pointer.
Definition: tensor_ref.h:326
Element Element
Data type of individual access.
Definition: tensor_view.h:72
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
TensorView< typename platform::remove_const< Element >::type const, Layout > ConstTensorView
TensorView pointing to constant memory.
Definition: tensor_view.h:95
Definition: tensor_view.h:56
CUTLASS_HOST_DEVICE void reset(Element *ptr, Layout const &layout, TensorCoord size)
Updates the pointer and layout object.
Definition: tensor_view.h:154
typename Layout::TensorCoord TensorCoord
Coordinate in logical tensor space.
Definition: tensor_view.h:87
Element & Reference
Reference type to an element.
Definition: tensor_view.h:75
CUTLASS_HOST_DEVICE void reset(Element *ptr=nullptr)
Updates only the pointer.
Definition: tensor_ref.h:235
Definition: tensor_ref.h:146
CUTLASS_HOST_DEVICE TensorRef ref() const
Returns a TensorRef pointing to the first element of the tensor.
Definition: tensor_view.h:187
typename Layout::Stride Stride
Coordinate in storage n-D array.
Definition: tensor_view.h:90
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
CUTLASS_HOST_DEVICE LongIndex offset(TensorCoord const &coord) const
Computes the offset of an index from the origin of the tensor.
Definition: tensor_ref.h:301
#define static_assert(__e, __m)
Definition: platform.h:153
CUTLASS_HOST_DEVICE Index extent(int dim) const
Returns the extent along a particular logical dimension.
Definition: tensor_view.h:171
CUTLASS_HOST_DEVICE TensorView subview(TensorCoord extent, TensorCoord const &location=TensorCoord()) const
Returns a TensorView given location and size quantities.
Definition: tensor_view.h:205
CUTLASS_HOST_DEVICE TensorView(TensorRef const &ref, TensorCoord const &extent)
Constructs a TensorView object.
Definition: tensor_view.h:137
CUTLASS_HOST_DEVICE TensorView(NonConstTensorView const &view)
Converting constructor from a TensorView to non-constant data.
Definition: tensor_view.h:147
typename Layout::Index Index
Index type.
Definition: tensor_view.h:81
cutlass::TensorRef< Element_, Layout_ > Base
Base tensor reference.
Definition: tensor_view.h:60
CUTLASS_HOST_DEVICE ConstTensorView const_view() const
Returns a TensorView to const data.
Definition: tensor_view.h:199
typename Base::ConstTensorRef ConstTensorRef
TensorRef pointing to constant memory.
Definition: tensor_view.h:66
CUTLASS_HOST_DEVICE Layout & layout()
Returns the layout object.
Definition: tensor_ref.h:265
CUTLASS_HOST_DEVICE TensorView(Element *ptr, Layout const &layout, TensorCoord const &extent)
Constructs a TensorView object.
Definition: tensor_view.h:126
CUTLASS_HOST_DEVICE TensorView(TensorCoord const &extent=TensorCoord())
Constructs a TensorView object.
Definition: tensor_view.h:120
CUTLASS_HOST_DEVICE TensorView< Element, Layout > make_TensorView(Element *ptr, Layout const &layout, typename Layout::TensorCoord const &extent)
Constructs a TensorView, deducing types from arguments.
Definition: tensor_view.h:269
CUTLASS_HOST_DEVICE bool contains(TensorCoord const &coord) const
Determines whether a location is within a tensor.
Definition: tensor_view.h:175
CUTLASS_HOST_DEVICE TensorRef & add_pointer_offset(LongIndex offset_)
Adds an offset to each pointer.
Definition: tensor_ref.h:319
Basic include for CUTLASS.
CUTLASS_HOST_DEVICE ConstTensorRef const_ref() const
Returns a ConstTensorRef (a TensorRef to constant data) pointing to the first element of the tensor.
Definition: tensor_view.h:193
typename Layout::LongIndex LongIndex
Long index used for pointer offsets.
Definition: tensor_view.h:84
Layout Layout
Mapping function from logical coordinate to internal n-D array.
Definition: tensor_view.h:63