CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
reduce.h
/***************************************************************************************************
 * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/half.h"
#include "cutlass/functional.h"

namespace cutlass {
namespace reduction {
namespace thread {

/// Structure to compute the thread level reduction
template <typename Op, typename T>
struct Reduce;

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization of Reduce for "plus" (a functional operator) applied to scalars
template <typename T>
struct Reduce<plus<T>, T> {

  CUTLASS_HOST_DEVICE
  T operator()(T lhs, T const &rhs) const {
    plus<T> _op;
    return _op(lhs, rhs);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization of Reduce for Array<T, N>
template <typename T, int N>
struct Reduce<plus<T>, Array<T, N>> {

  CUTLASS_HOST_DEVICE
  Array<T, 1> operator()(Array<T, N> const &in) const {

    Array<T, 1> result;
    Reduce<plus<T>, T> scalar_reduce;
    result.clear();

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      result[0] = scalar_reduce(result[0], in[i]);
    }

    return result;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization of Reduce for Array<half_t, N>
template <int N>
struct Reduce<plus<half_t>, Array<half_t, N>> {

  CUTLASS_HOST_DEVICE
  Array<half_t, 1> operator()(Array<half_t, N> const &input) {

    Array<half_t, 1> result;

    // If there is only 1 element - there is nothing to reduce
    if (N == 1) {

      result[0] = input.front();

    } else {

      #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)

      __half result_d;
      Array<half_t, 1> const *in_ptr_half = reinterpret_cast<Array<half_t, 1> const *>(&input);
      Array<half_t, 2> const *in_ptr_half2 = reinterpret_cast<Array<half_t, 2> const *>(&input);
      __half2 const *x_in_half2 = reinterpret_cast<__half2 const *>(in_ptr_half2);

      // Set initial result = first half2, in case N==2
      __half2 tmp_result = x_in_half2[0];

      CUTLASS_PRAGMA_UNROLL
      for (int i = 1; i < N / 2; ++i) {
        tmp_result = __hadd2(x_in_half2[i], tmp_result);
      }

      result_d = __hadd(__low2half(tmp_result), __high2half(tmp_result));

      // One final step is needed for odd "N" (to add the (N-1)th element)
      if (N % 2) {

        __half last_element;
        Array<half_t, 1> tmp_last;
        Array<half_t, 1> *tmp_last_ptr = &tmp_last;
        tmp_last_ptr[0] = in_ptr_half[N - 1];
        last_element = reinterpret_cast<__half const &>(tmp_last);

        result_d = __hadd(result_d, last_element);
      }

      Array<half_t, 1> *result_ptr = &result;
      *result_ptr = reinterpret_cast<Array<half_t, 1> &>(result_d);

      #else

      Reduce<plus<half_t>, half_t> scalar_reduce;
      result.clear();

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < N; ++i) {
        result[0] = scalar_reduce(result[0], input[i]);
      }

      #endif
    }

    return result;
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Partial specialization of Reduce for AlignedArray<half_t, N>
template <int N>
struct Reduce<plus<half_t>, AlignedArray<half_t, N>> {

  CUTLASS_HOST_DEVICE
  Array<half_t, 1> operator()(AlignedArray<half_t, N> const &input) {

    Array<half_t, 1> result;

    // If there is only 1 element - there is nothing to reduce
    if (N == 1) {

      result[0] = input.front();

    } else {

      #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)

      __half result_d;
      AlignedArray<half_t, 1> const *in_ptr_half = reinterpret_cast<AlignedArray<half_t, 1> const *>(&input);
      AlignedArray<half_t, 2> const *in_ptr_half2 = reinterpret_cast<AlignedArray<half_t, 2> const *>(&input);
      __half2 const *x_in_half2 = reinterpret_cast<__half2 const *>(in_ptr_half2);

      // Set initial result = first half2, in case N==2
      __half2 tmp_result = x_in_half2[0];

      CUTLASS_PRAGMA_UNROLL
      for (int i = 1; i < N / 2; ++i) {
        tmp_result = __hadd2(x_in_half2[i], tmp_result);
      }

      result_d = __hadd(__low2half(tmp_result), __high2half(tmp_result));

      // One final step is needed for odd "N" (to add the (N-1)th element)
      if (N % 2) {

        __half last_element;
        AlignedArray<half_t, 1> tmp_last;
        AlignedArray<half_t, 1> *tmp_last_ptr = &tmp_last;
        tmp_last_ptr[0] = in_ptr_half[N - 1];
        last_element = reinterpret_cast<__half const &>(tmp_last);

        result_d = __hadd(result_d, last_element);
      }

      Array<half_t, 1> *result_ptr = &result;
      *result_ptr = reinterpret_cast<Array<half_t, 1> &>(result_d);

      #else

      Reduce<plus<half_t>, half_t> scalar_reduce;
      result.clear();

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < N; ++i) {
        result[0] = scalar_reduce(result[0], input[i]);
      }

      #endif
    }

    return result;
  }
};

} // namespace thread
} // namespace reduction
} // namespace cutlass
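
For reference, a minimal usage sketch of the thread-level reduction defined above. It assumes this header's usual location in the CUTLASS tree (cutlass/reduction/thread/reduce.h); the kernel name and the flat per-thread input layout are illustrative only and not part of CUTLASS.

// Illustrative example: each thread sums an 8-element fragment.
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/reduction/thread/reduce.h"

__global__ void reduce_fragment_example(float const *in, float *out) {

  constexpr int kN = 8;

  // Load one fragment per thread (contiguous layout assumed for illustration).
  cutlass::Array<float, kN> fragment;

  CUTLASS_PRAGMA_UNROLL
  for (int i = 0; i < kN; ++i) {
    fragment[i] = in[threadIdx.x * kN + i];
  }

  // Reduce<plus<T>, Array<T, N>> returns an Array<T, 1> holding the sum
  // of the fragment's elements.
  cutlass::reduction::thread::Reduce<
      cutlass::plus<float>,
      cutlass::Array<float, kN>> reduce_op;

  cutlass::Array<float, 1> sum = reduce_op(fragment);

  out[threadIdx.x] = sum[0];
}

The same pattern applies with cutlass::half_t in place of float; on SM60 and newer, the half_t specializations above reduce pairs of elements with __hadd2 rather than looping over scalars.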