CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
host/tensor_compare.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
25 /* \file
26  \brief Defines host-side elementwise comparison and search operations on TensorView.
27 */
28 
29 #pragma once
30 
31 // Standard Library includes
32 #include <utility>
33 
34 // Cutlass includes
35 #include "cutlass/cutlass.h"
37 //#include "cutlass/util/type_traits.h"
38 #include "tensor_foreach.h"
39 
40 namespace cutlass {
41 namespace reference {
42 namespace host {
43 
46 
47 namespace detail {
48 
49 template <
50  typename Element,
51  typename Layout>
53 
54  //
55  // Data members
56  //
57 
60  bool result;
61 
63  TensorEqualsFunc(): result(true) { }
64 
67  TensorView<Element, Layout> const &lhs_,
68  TensorView<Element, Layout> const &rhs_
69  ) :
70  lhs(lhs_), rhs(rhs_), result(true) { }
71 
73  void operator()(Coord<Layout::kRank> const &coord) {
74 
75  Element lhs_ = lhs.at(coord);
76  Element rhs_ = rhs.at(coord);
77 
78  if (lhs_ != rhs_) {
79  result = false;
80  }
81  }
82 
84  operator bool() const {
85  return result;
86  }
87 };
88 
89 } // namespace detail
90 
92 
94 template <
95  typename Element,
96  typename Layout>
100 
101  // Extents must be identical
102  if (lhs.extent() != rhs.extent()) {
103  return false;
104  }
105 
108  lhs.extent(),
109  func
110  );
111 
112  return bool(func);
113 }
114 
117 
119 template <
120  typename Element,
121  typename Layout>
125 
126  // Extents must be identical
127  if (lhs.extent() != rhs.extent()) {
128  return true;
129  }
130 
133  lhs.extent(),
134  func
135  );
136 
137  return !bool(func);
138 }
139 
142 
143 namespace detail {
144 
145 template <
146  typename Element,
147  typename Layout>
149 
150  //
151  // Data members
152  //
153 
155  Element value;
156  bool contains;
158 
159  //
160  // Methods
161  //
162 
164  TensorContainsFunc(): contains(false) { }
165 
168  TensorView<Element, Layout> const &view_,
169  Element value_
170  ) :
171  view(view_), value(value_), contains(false) { }
172 
174  void operator()(Coord<Layout::kRank> const &coord) {
175 
176  if (view.at(coord) == value) {
177  if (!contains) {
178  location = coord;
179  }
180  contains = true;
181  }
182  }
183 
185  operator bool() const {
186  return contains;
187  }
188 };
189 
190 } // namespace detail
191 
193 
195 template <
196  typename Element,
197  typename Layout>
199  TensorView<Element, Layout> const & view,
200  Element value) {
201 
203  view,
204  value
205  );
206 
208  view.extent(),
209  func
210  );
211 
212  return bool(func);
213 }
214 
216 
220 template <
221  typename Element,
222  typename Layout>
223 std::pair<bool, Coord<Layout::kRank> > TensorFind(
224  TensorView<Element, Layout> const & view,
225  Element value) {
226 
228  view,
229  value
230  );
231 
233  view.extent(),
234  func
235  );
236 
237  return std::make_pair(bool(func), func.location);
238 }
239 
242 
243 } // namespace host
244 } // namespace reference
245 } // namespace cutlass
TensorContainsFunc()
Ctor.
Definition: host/tensor_compare.h:164
Definition: aligned_buffer.h:35
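Functor that visits coordinates of a TensorView and records whether a given value occurs, along with the coordinate of its first occurrence.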
Definition: host/tensor_compare.h:148
TensorContainsFunc(TensorView< Element, Layout > const &view_, Element value_)
Ctor.
Definition: host/tensor_compare.h:167
CUTLASS_HOST_DEVICE TensorCoord const & extent() const
Returns the extent of the view (the size along each logical dimension).
Definition: tensor_view.h:167
CUTLASS_HOST_DEVICE std::pair< T1, T2 > make_pair(T1 t, T2 u)
Definition: platform.h:232
void operator()(Coord< Layout::kRank > const &coord)
Visits a coordinate.
Definition: host/tensor_compare.h:174
bool TensorEquals(TensorView< Element, Layout > const &lhs, TensorView< Element, Layout > const &rhs)
Returns true if two tensor views are equal.
Definition: host/tensor_compare.h:97
Coord< Layout::kRank > location
Definition: host/tensor_compare.h:157
TensorView< Element, Layout > view
Definition: host/tensor_compare.h:154
TensorView< Element, Layout > lhs
Definition: host/tensor_compare.h:58
void operator()(Coord< Layout::kRank > const &coord)
Visits a coordinate.
Definition: host/tensor_compare.h:73
bool TensorContains(TensorView< Element, Layout > const &view, Element value)
Returns true if a value is present in a tensor.
Definition: host/tensor_compare.h:198
bool contains
Definition: host/tensor_compare.h:156
TensorEqualsFunc()
Ctor.
Definition: host/tensor_compare.h:63
This header contains a class to parametrize a statistical distribution function.
bool TensorNotEquals(TensorView< Element, Layout > const &lhs, TensorView< Element, Layout > const &rhs)
Returns true if two tensor views are NOT equal.
Definition: host/tensor_compare.h:122
Element value
Definition: host/tensor_compare.h:155
CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const
Returns a reference to the element at a given Coord.
Definition: tensor_ref.h:307
TensorEqualsFunc(TensorView< Element, Layout > const &lhs_, TensorView< Element, Layout > const &rhs_)
Ctor.
Definition: host/tensor_compare.h:66
void TensorForEach(Coord< Rank > extent, Func &func)
Iterates over the index space of a tensor.
Definition: host/tensor_foreach.h:87
bool result
Definition: host/tensor_compare.h:60
std::pair< bool, Coord< Layout::kRank > > TensorFind(TensorView< Element, Layout > const &view, Element value)
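Returns a pair: whether the value is present in the tensor, and the coordinate of its first occurrence (not meaningful when the value is absent).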
Definition: host/tensor_compare.h:223
Basic include for CUTLASS.
TensorView< Element, Layout > rhs
Definition: host/tensor_compare.h:59
Functor that compares two TensorViews element by element; its result becomes false on any mismatch.
Definition: host/tensor_compare.h:52