42 template <
typename Op,
typename T>
61 template <
typename T,
int N>
72 for (
auto i = 0; i < N; ++i) {
73 result[0] = scalar_reduce(result[0], in[i]);
87 Array<half_t, 1>
operator()(Array<half_t, N>
const &input) {
89 Array<half_t, 1> result;
94 result[0] = input.front();
98 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600) 101 Array<half_t, 1>
const *in_ptr_half =
reinterpret_cast<Array<half_t, 1>
const *
>(&input);
102 Array<half_t, 2>
const *in_ptr_half2 =
reinterpret_cast<Array<half_t, 2>
const *
>(&input);
103 __half2
const *x_in_half2 =
reinterpret_cast<__half2
const *
>(in_ptr_half2);
106 __half2 tmp_result = x_in_half2[0];
109 for (
int i = 1; i < N/2; ++i) {
111 tmp_result = __hadd2(x_in_half2[i], tmp_result);
115 result_d = __hadd(__low2half(tmp_result), __high2half(tmp_result));
121 Array<half_t, 1> tmp_last;
122 Array<half_t, 1> *tmp_last_ptr = &tmp_last;
123 tmp_last_ptr[0] = in_ptr_half[N-1];
124 last_element =
reinterpret_cast<__half
const &
>(tmp_last);
126 result_d = __hadd(result_d, last_element);
130 Array<half_t, 1> *result_ptr = &result;
131 *result_ptr =
reinterpret_cast<Array<half_t, 1> &
>(result_d);
139 for (
auto i = 0; i < N; ++i) {
141 result[0] = scalar_reduce(result[0], input[i]);
163 Array<half_t, 1> result;
168 result[0] = input.front();
172 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600) 177 __half2
const *x_in_half2 =
reinterpret_cast<__half2
const *
>(in_ptr_half2);
180 __half2 tmp_result = x_in_half2[0];
183 for (
int i = 1; i < N/2; ++i) {
185 tmp_result = __hadd2(x_in_half2[i], tmp_result);
189 result_d = __hadd(__low2half(tmp_result), __high2half(tmp_result));
197 tmp_last_ptr[0] = in_ptr_half[N-1];
198 last_element =
reinterpret_cast<__half
const &
>(tmp_last);
200 result_d = __hadd(result_d, last_element);
204 Array<half_t, 1> *result_ptr = &result;
205 *result_ptr =
reinterpret_cast<Array<half_t, 1> &
>(result_d);
213 for (
auto i = 0; i < N; ++i) {
215 result[0] = scalar_reduce(result[0], input[i]);
Definition: aligned_buffer.h:35
Defines a class for using IEEE half-precision floating-point types in host or device code...
Aligned array type.
Definition: array.h:511
IEEE half-precision floating-point type.
Definition: half.h:126
CUTLASS_HOST_DEVICE Array< half_t, 1 > operator()(AlignedArray< half_t, N > const &input)
Definition: reduce.h:161
Definition: functional.h:46
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
CUTLASS_HOST_DEVICE Array< T, 1 > operator()(Array< T, N > const &in) const
Definition: reduce.h:65
CUTLASS_HOST_DEVICE Array< half_t, 1 > operator()(Array< half_t, N > const &input)
Definition: reduce.h:87
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
Basic include for CUTLASS.
CUTLASS_HOST_DEVICE T operator()(T lhs, T const &rhs) const
Definition: reduce.h:52
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...
Structure to compute the thread level reduction.
Definition: reduce.h:43