CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
host/tensor_elementwise.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
25 /* \file
26  \brief Defines host-side elementwise operations on TensorView.
27 */
28 
29 #pragma once
30 
31 // Cutlass includes
32 #include "cutlass/cutlass.h"
33 #include "cutlass/functional.h"
34 
35 #include "tensor_foreach.h"
36 
37 namespace cutlass {
38 namespace reference {
39 namespace host {
40 
43 
44 namespace detail {
45 
47 
49 template <
50  typename ElementA,
51  typename LayoutA,
52  typename ElementB,
53  typename LayoutB,
54  typename ElementD,
55  typename LayoutD,
56  typename BinaryFunc>
58 
59  //
60  // Data members
61  //
62 
67  BinaryFunc func;
68 
69  //
70  // Methods
71  //
72 
75 
78  TensorView<ElementD, LayoutD> const & view_d_,
79  TensorRef<ElementA, LayoutA> const & ref_a_,
80  TensorRef<ElementB, LayoutB> const & ref_b_,
81  BinaryFunc func = BinaryFunc()
82  ):
83  view_d(view_d_), view_a(view_a_), view_b(view_b_), func(func) { }
84 
86  void operator()(Coord<LayoutD::kRank> const &coord) const {
87  view_d.at(coord) = func(
88  ElementD(view_a.at(coord)),
89  ElementD(view_b.at(coord))
90  );
91  }
92 };
93 
94 } // namespace detail
95 
98 
100 template <
101  typename ElementD,
102  typename LayoutD,
103  typename ElementA,
104  typename LayoutA,
105  typename ElementB,
106  typename LayoutB
107 >
112 ) {
113 
115  ElementD,
116  LayoutD,
117  ElementA,
118  LayoutA,
119  ElementB,
120  LayoutB,
122  > func(d, a, b);
123 
125  d.extent(),
126  func);
127 }
128 
130 template <
131  typename ElementD,
132  typename LayoutD,
133  typename ElementA,
134  typename LayoutA
135 >
139 ) {
140  TensorAdd(d, d, a);
141 }
142 
144 
146 template <
147  typename ElementD,
148  typename LayoutD,
149  typename ElementA,
150  typename LayoutA,
151  typename ElementB,
152  typename LayoutB
153 >
158  ) {
159 
161  ElementD,
162  LayoutD,
163  ElementA,
164  LayoutA,
165  ElementB,
166  LayoutB,
168  > func(d, a, b);
169 
171  d.extent(),
172  func);
173 }
174 
176 template <
177  typename ElementD,
178  typename LayoutD,
179  typename ElementA,
180  typename LayoutA,
181  typename ElementB,
182  typename LayoutB
183 >
187  ) {
188 
189  TensorSub(d, d, a);
190 }
191 
193 
195 template <
196  typename ElementD,
197  typename LayoutD,
198  typename ElementA,
199  typename LayoutA,
200  typename ElementB,
201  typename LayoutB
202 >
207 ) {
208 
210  ElementD,
211  LayoutD,
212  ElementA,
213  LayoutA,
214  ElementB,
215  LayoutB,
217  > func(d, a, b);
218 
220  d.extent(),
221  func);
222 }
223 
225 template <
226  typename ElementD,
227  typename LayoutD,
228  typename ElementA,
229  typename LayoutA
230 >
234 ) {
235  TensorMul(d, d, a);
236 }
237 
239 
241 template <
242  typename ElementD,
243  typename LayoutD,
244  typename ElementA,
245  typename LayoutA,
246  typename ElementB,
247  typename LayoutB
248 >
253 ) {
254 
256  ElementD,
257  LayoutD,
258  ElementA,
259  LayoutA,
260  ElementB,
261  LayoutB,
263  > func(d, a, b);
264 
266  d.extent(),
267  func);
268 }
269 
271 template <
272  typename ElementD,
273  typename LayoutD,
274  typename ElementA,
275  typename LayoutA
276 >
280 ) {
281  TensorMul(d, d, a);
282 }
283 
284 
286 
288 template <
289  typename ElementD,
290  typename LayoutD,
291  typename ElementA,
292  typename LayoutA,
293  typename ElementB,
294  typename LayoutB
295 >
300 ) {
301 
303  ElementD,
304  LayoutD,
305  ElementA,
306  LayoutA,
307  ElementB,
308  LayoutB,
309  cutlass::modulus<ElementD>
310  > func(d, a, b);
311 
313  d.extent(),
314  func);
315 }
316 
318 template <
319  typename ElementD,
320  typename LayoutD,
321  typename ElementA,
322  typename LayoutA
323 >
327 ) {
328  TensorMul(d, d, a);
329 }
330 
332 
333 } // namespace host
334 } // namespace reference
335 } // namespace cutlass
void operator()(Coord< LayoutD::kRank > const &coord) const
Applies the binary function to the elements of a and b at the given coordinate and stores the result in d.
Definition: host/tensor_elementwise.h:86
Definition: aligned_buffer.h:35
Helper functor that applies a binary operator elementwise: d = func(a, b).
Definition: host/tensor_elementwise.h:57
void TensorAdd(TensorView< ElementD, LayoutD > d, TensorRef< ElementA, LayoutA > a, TensorRef< ElementB, LayoutB > b)
Adds two tensors and stores in the destination tensor: d = a + b.
Definition: host/tensor_elementwise.h:108
CUTLASS_HOST_DEVICE TensorCoord const & extent() const
Returns the extent of the view (the size along each logical dimension).
Definition: tensor_view.h:167
void TensorDiv(TensorView< ElementD, LayoutD > d, TensorRef< ElementA, LayoutA > a, TensorRef< ElementB, LayoutB > b)
Divides two tensors and stores in the destination tensor: d = a ./ b.
Definition: host/tensor_elementwise.h:249
BinaryFunc func
Definition: host/tensor_elementwise.h:67
Definition: functional.h:46
TensorFuncBinaryOp(TensorView< ElementD, LayoutD > const &view_d_, TensorRef< ElementA, LayoutA > const &ref_a_, TensorRef< ElementB, LayoutB > const &ref_b_, BinaryFunc func=BinaryFunc())
Constructor.
Definition: host/tensor_elementwise.h:77
TensorRef< ElementB, LayoutB > ref_b
Definition: host/tensor_elementwise.h:66
Definition: functional.h:64
void TensorSub(TensorView< ElementD, LayoutD > d, TensorRef< ElementA, LayoutA > a, TensorRef< ElementB, LayoutB > b)
Subtracts two tensors and stores in the destination tensor: d = a - b.
Definition: host/tensor_elementwise.h:154
TensorView< ElementD, LayoutD > view_d
View of the destination tensor.
Definition: host/tensor_elementwise.h:64
Definition: functional.h:73
void TensorMul(TensorView< ElementD, LayoutD > d, TensorRef< ElementA, LayoutA > a, TensorRef< ElementB, LayoutB > b)
Multiplies two tensors and stores in the destination tensor: d = a .* b.
Definition: host/tensor_elementwise.h:203
Statically-sized array specifying Coords within a tensor.
Definition: coord.h:43
void TensorModulus(TensorView< ElementD, LayoutD > d, TensorRef< ElementA, LayoutA > a, TensorRef< ElementB, LayoutB > b)
Computes the elementwise modulus of two tensors and stores in the destination tensor: d = a % b.
Definition: host/tensor_elementwise.h:296
CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const
Returns a reference to the element at a given Coord.
Definition: tensor_ref.h:307
void TensorForEach(Coord< Rank > extent, Func &func)
Iterates over the index space of a tensor.
Definition: host/tensor_foreach.h:87
Definition: functional.h:55
TensorFuncBinaryOp()
Constructor.
Definition: host/tensor_elementwise.h:74
Basic include for CUTLASS.
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
TensorRef< ElementA, LayoutA > ref_a
Definition: host/tensor_elementwise.h:65