CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
functional.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
31 #pragma once
32 
33 #include "cutlass/cutlass.h"
34 #include "cutlass/numeric_types.h"
35 
36 #include "cutlass/complex.h"
37 
38 #include "cutlass/array.h"
39 #include "cutlass/half.h"
40 
41 namespace cutlass {
42 
44 
/// Function object computing the elementwise sum: returns lhs + rhs.
/// lhs is taken by value so the compound assignment mutates only the copy.
template <typename T>
struct plus {
  T operator()(T lhs, T const &rhs) const {
    lhs += rhs;
    return lhs;
  }
};
53 
/// Function object computing the elementwise difference: returns lhs - rhs.
/// lhs is taken by value so the compound assignment mutates only the copy.
template <typename T>
struct minus {
  T operator()(T lhs, T const &rhs) const {
    lhs -= rhs;
    return lhs;
  }
};
62 
/// Function object computing the elementwise product: returns lhs * rhs.
/// lhs is taken by value so the compound assignment mutates only the copy.
template <typename T>
struct multiplies {
  T operator()(T lhs, T const &rhs) const {
    lhs *= rhs;
    return lhs;
  }
};
71 
/// Function object computing the elementwise quotient: returns lhs / rhs.
/// No zero-divisor check is performed; behavior follows T's operator/=.
template <typename T>
struct divides {
  T operator()(T lhs, T const &rhs) const {
    lhs /= rhs;
    return lhs;
  }
};
80 
81 
/// Function object computing the unary negation: returns -lhs.
template <typename T>
struct negate {
  T operator()(T lhs) const {
    return -lhs;
  }
};
89 
/// Fused multiply-add function object: returns a * b + c.
/// Operands may have distinct types A, B; both are converted to the
/// accumulator type C before the multiply so the product is formed in C.
template <typename A, typename B = A, typename C = A>
struct multiply_add {
  C operator()(A const &a, B const &b, C const &c) const {
    return C(a) * C(b) + c;
  }
};
98 
/// Fused XOR-add function object: returns (a ^ b) + c.
/// Useful for binarized / XOR-popc style GEMM accumulation.
template <typename T>
struct xor_add {
  T operator()(T const &a, T const &b, T const &c) const {
    return ((a ^ b) + c);
  }
};
107 
109 //
110 // Partial specialization for complex<T> to target four scalar fused multiply-adds.
111 //
113 
115 template <typename T>
116 struct multiply_add<complex<T>, complex<T>, complex<T>> {
119  complex<T> const &a,
120  complex<T> const &b,
121  complex<T> const &c) const {
122 
123  T real = c.real();
124  T imag = c.imag();
125 
126  real += a.real() * b.real();
127  real += -a.imag() * b.imag();
128  imag += a.real() * b.imag();
129  imag += a.imag () * b.real();
130 
131  return complex<T>{
132  real,
133  imag
134  };
135  }
136 };
137 
139 template <typename T>
140 struct multiply_add<complex<T>, T, complex<T>> {
143  complex<T> const &a,
144  T const &b,
145  complex<T> const &c) const {
146 
147  T real = c.real();
148  T imag = c.imag();
149 
150  real += a.real() * b;
151  imag += a.imag () * b;
152 
153  return complex<T>{
154  real,
155  imag
156  };
157  }
158 };
159 
161 template <typename T>
162 struct multiply_add<T, complex<T>, complex<T>> {
165  T const &a,
166  complex<T> const &b,
167  complex<T> const &c) const {
168 
169  T real = c.real();
170  T imag = c.imag();
171 
172  real += a * b.real();
173  imag += a * b.imag();
174 
175  return complex<T>{
176  real,
177  imag
178  };
179  }
180 };
181 
183 //
184 // Partial specializations for Array<T, N>
185 //
187 
188 template <typename T, int N>
189 struct plus<Array<T, N>> {
191  Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
192 
193  Array<T, N> result;
194  plus<T> scalar_op;
195 
197  for (int i = 0; i < N; ++i) {
198  result[i] = scalar_op(lhs[i], rhs[i]);
199  }
200 
201  return result;
202  }
203 
205  Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
206 
207  Array<T, N> result;
208  plus<T> scalar_op;
209 
211  for (int i = 0; i < N; ++i) {
212  result[i] = scalar_op(lhs[i], scalar);
213  }
214 
215  return result;
216  }
217 
219  Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
220 
221  Array<T, N> result;
222  plus<T> scalar_op;
223 
225  for (int i = 0; i < N; ++i) {
226  result[i] = scalar_op(scalar, rhs[i]);
227  }
228 
229  return result;
230  }
231 };
232 
233 
/// Function object returning the larger of two values using operator<.
/// When the operands compare equal, lhs is returned.
template <typename T>
struct maximum {

  T operator()(T const &lhs, T const &rhs) const {
    return (lhs < rhs ? rhs : lhs);
  }
};
242 
243 template <>
244 struct maximum<float> {
246  float operator()(float const &lhs, float const &rhs) const {
247  return fmaxf(lhs, rhs);
248  }
249 };
250 
251 template <typename T, int N>
252 struct maximum<Array<T, N>> {
253 
255  Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
256 
257  Array<T, N> result;
258  maximum<T> scalar_op;
259 
261  for (int i = 0; i < N; ++i) {
262  result[i] = scalar_op(lhs[i], rhs[i]);
263  }
264 
265  return result;
266  }
267 
269  Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
270 
271  Array<T, N> result;
272  maximum<T> scalar_op;
273 
275  for (int i = 0; i < N; ++i) {
276  result[i] = scalar_op(lhs[i], scalar);
277  }
278 
279  return result;
280  }
281 
283  Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
284 
285  Array<T, N> result;
286  maximum<T> scalar_op;
287 
289  for (int i = 0; i < N; ++i) {
290  result[i] = scalar_op(scalar, rhs[i]);
291  }
292 
293  return result;
294  }
295 };
296 
/// Function object returning the smaller of two values using operator<.
/// When the operands compare equal, lhs is returned.
template <typename T>
struct minimum {

  T operator()(T const &lhs, T const &rhs) const {
    return (rhs < lhs ? rhs : lhs);
  }
};
305 
306 template <>
307 struct minimum<float> {
309  float operator()(float const &lhs, float const &rhs) const {
310  return fminf(lhs, rhs);
311  }
312 };
313 
314 template <typename T, int N>
315 struct minimum<Array<T, N>> {
316 
318  static T scalar_op(T const &lhs, T const &rhs) {
319  return (rhs < lhs ? rhs : lhs);
320  }
321 
323  Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
324 
325  Array<T, N> result;
326  minimum<T> scalar_op;
327 
329  for (int i = 0; i < N; ++i) {
330  result[i] = scalar_op(lhs[i], rhs[i]);
331  }
332 
333  return result;
334  }
335 
337  Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
338 
339  Array<T, N> result;
340  minimum<T> scalar_op;
341 
343  for (int i = 0; i < N; ++i) {
344  result[i] = scalar_op(lhs[i], scalar);
345  }
346 
347  return result;
348  }
349 
351  Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
352 
353  Array<T, N> result;
354  minimum<T> scalar_op;
355 
357  for (int i = 0; i < N; ++i) {
358  result[i] = scalar_op(scalar, rhs[i]);
359  }
360 
361  return result;
362  }
363 };
364 
365 template <typename T, int N>
366 struct minus<Array<T, N>> {
367 
369  Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
370 
371  Array<T, N> result;
372  minus<T> scalar_op;
373 
375  for (int i = 0; i < N; ++i) {
376  result[i] = scalar_op(lhs[i], rhs[i]);
377  }
378 
379  return result;
380  }
381 
383  Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
384 
385  Array<T, N> result;
386  minus<T> scalar_op;
387 
389  for (int i = 0; i < N; ++i) {
390  result[i] = scalar_op(lhs[i], scalar);
391  }
392 
393  return result;
394  }
395 
397  Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
398 
399  Array<T, N> result;
400  minus<T> scalar_op;
401 
403  for (int i = 0; i < N; ++i) {
404  result[i] = scalar_op(scalar, rhs[i]);
405  }
406 
407  return result;
408  }
409 };
410 
411 template <typename T, int N>
412 struct multiplies<Array<T, N>> {
413 
415  Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
416 
417  Array<T, N> result;
418  multiplies<T> scalar_op;
419 
421  for (int i = 0; i < N; ++i) {
422  result[i] = scalar_op(lhs[i], rhs[i]);
423  }
424 
425  return result;
426  }
427 
429  Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
430 
431  Array<T, N> result;
432  multiplies<T> scalar_op;
433 
435  for (int i = 0; i < N; ++i) {
436  result[i] = scalar_op(lhs[i], scalar);
437  }
438 
439  return result;
440  }
441 
443  Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
444 
445  Array<T, N> result;
446  multiplies<T> scalar_op;
447 
449  for (int i = 0; i < N; ++i) {
450  result[i] = scalar_op(scalar, rhs[i]);
451  }
452 
453  return result;
454  }
455 };
456 
457 template <typename T, int N>
458 struct divides<Array<T, N>> {
459 
461  Array<T, N> operator()(Array<T, N> const &lhs, Array<T, N> const &rhs) const {
462 
463  Array<T, N> result;
464  divides<T> scalar_op;
465 
467  for (int i = 0; i < N; ++i) {
468  result[i] = scalar_op(lhs[i], rhs[i]);
469  }
470 
471  return result;
472  }
473 
475  Array<T, N> operator()(Array<T, N> const &lhs, T const &scalar) const {
476 
477  Array<T, N> result;
478  divides<T> scalar_op;
479 
481  for (int i = 0; i < N; ++i) {
482  result[i] = scalar_op(lhs[i], scalar);
483  }
484 
485  return result;
486  }
487 
489  Array<T, N> operator()( T const &scalar, Array<T, N> const &rhs) const {
490 
491  Array<T, N> result;
492  divides<T> scalar_op;
493 
495  for (int i = 0; i < N; ++i) {
496  result[i] = scalar_op(scalar, rhs[i]);
497  }
498 
499  return result;
500  }
501 };
502 
503 
504 template <typename T, int N>
505 struct negate<Array<T, N>> {
506 
508  Array<T, N> operator()(Array<T, N> const &lhs) const {
509 
510  Array<T, N> result;
511  negate<T> scalar_op;
512 
514  for (int i = 0; i < N; ++i) {
515  result[i] = scalar_op(lhs[i]);
516  }
517 
518  return result;
519  }
520 };
521 
523 template <typename T, int N>
524 struct multiply_add<Array<T, N>, Array<T, N>, Array<T, N>> {
525 
527  Array<T, N> operator()(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) const {
528 
529  Array<T, N> result;
530  multiply_add<T> scalar_op;
531 
533  for (int i = 0; i < N; ++i) {
534  result[i] = scalar_op(a[i], b[i], c[i]);
535  }
536 
537  return result;
538  }
539 
541  Array<T, N> operator()(Array<T, N> const &a, T const &scalar, Array<T, N> const &c) const {
542 
543  Array<T, N> result;
544  multiply_add<T> scalar_op;
545 
547  for (int i = 0; i < N; ++i) {
548  result[i] = scalar_op(a[i], scalar, c[i]);
549  }
550 
551  return result;
552  }
553 
555  Array<T, N> operator()(T const &scalar, Array<T, N> const &b, Array<T, N> const &c) const {
556 
557  Array<T, N> result;
558  multiply_add<T> scalar_op;
559 
561  for (int i = 0; i < N; ++i) {
562  result[i] = scalar_op(scalar, b[i], c[i]);
563  }
564 
565  return result;
566  }
567 };
568 
570 //
571 // Partial specializations for Array<half_t, N> targeting SIMD instructions in device code.
572 //
574 
575 template <int N>
576 struct plus<Array<half_t, N>> {
578  Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {
579  Array<half_t, N> result;
580  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
581 
582  __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
583  __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
584  __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
585 
587  for (int i = 0; i < N / 2; ++i) {
588  result_ptr[i] = __hadd2(lhs_ptr[i], rhs_ptr[i]);
589  }
590 
591  if (N % 2) {
592  __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
593  __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
594  __half d_residual = __hadd(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);
595 
596  result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
597  }
598 
599  #else
600 
602  for (int i = 0; i < N; ++i) {
603  result[i] = lhs[i] + rhs[i];
604  }
605  #endif
606 
607  return result;
608  }
609 
611  Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
612  Array<half_t, N> result;
613  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
614 
615  __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
616  __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
617  __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
618 
620  for (int i = 0; i < N / 2; ++i) {
621  result_ptr[i] = __hadd2(lhs_pair, rhs_ptr[i]);
622  }
623 
624  if (N % 2) {
625  __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
626  __half d_residual = __hadd(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]);
627 
628  result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
629  }
630 
631  #else
632 
634  for (int i = 0; i < N; ++i) {
635  result[i] = lhs + rhs[i];
636  }
637  #endif
638 
639  return result;
640  }
641 
643  Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
644  Array<half_t, N> result;
645  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
646 
647  __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
648  __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
649  __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));
650 
652  for (int i = 0; i < N / 2; ++i) {
653  result_ptr[i] = __hadd2(lhs_ptr[i], rhs_pair);
654  }
655 
656  if (N % 2) {
657  __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
658  __half d_residual = __hadd(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs));
659 
660  result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
661  }
662 
663  #else
664 
666  for (int i = 0; i < N; ++i) {
667  result[i] = lhs[i] + rhs;
668  }
669  #endif
670 
671  return result;
672  }
673 };
674 
675 template <int N>
676 struct minus<Array<half_t, N>> {
678  Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {
679  Array<half_t, N> result;
680  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
681 
682  __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
683  __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
684  __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
685 
687  for (int i = 0; i < N / 2; ++i) {
688  result_ptr[i] = __hsub2(lhs_ptr[i], rhs_ptr[i]);
689  }
690 
691  if (N % 2) {
692  __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
693  __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
694  __half d_residual = __hsub(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);
695 
696  result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
697  }
698 
699  #else
700 
702  for (int i = 0; i < N; ++i) {
703  result[i] = lhs[i] - rhs[i];
704  }
705  #endif
706 
707  return result;
708  }
709 
711  Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
712  Array<half_t, N> result;
713  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
714 
715  __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
716  __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
717  __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
718 
720  for (int i = 0; i < N / 2; ++i) {
721  result_ptr[i] = __hsub2(lhs_pair, rhs_ptr[i]);
722  }
723 
724  if (N % 2) {
725  __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
726  __half d_residual = __hsub(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]);
727 
728  result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
729  }
730 
731  #else
732 
734  for (int i = 0; i < N; ++i) {
735  result[i] = lhs - rhs[i];
736  }
737  #endif
738 
739  return result;
740  }
741 
743  Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
744  Array<half_t, N> result;
745  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
746 
747  __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
748  __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
749  __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));
750 
752  for (int i = 0; i < N / 2; ++i) {
753  result_ptr[i] = __hsub2(lhs_ptr[i], rhs_pair);
754  }
755 
756  if (N % 2) {
757  __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
758  __half d_residual = __hsub(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs));
759 
760  result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
761  }
762 
763  #else
764 
766  for (int i = 0; i < N; ++i) {
767  result[i] = lhs[i] - rhs;
768  }
769  #endif
770 
771  return result;
772  }
773 };
774 
775 template <int N>
776 struct multiplies<Array<half_t, N>> {
778  Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {
779  Array<half_t, N> result;
780  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
781 
782  __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
783  __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
784  __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
785 
787  for (int i = 0; i < N / 2; ++i) {
788  result_ptr[i] = __hmul2(lhs_ptr[i], rhs_ptr[i]);
789  }
790 
791  if (N % 2) {
792  __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
793  __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
794  __half d_residual = __hmul(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);
795 
796  result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
797  }
798 
799  #else
800 
802  for (int i = 0; i < N; ++i) {
803  result[i] = lhs[i] * rhs[i];
804  }
805  #endif
806 
807  return result;
808  }
809 
811  Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
812  Array<half_t, N> result;
813  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
814 
815  __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
816  __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
817  __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
818 
820  for (int i = 0; i < N / 2; ++i) {
821  result_ptr[i] = __hmul2(lhs_pair, rhs_ptr[i]);
822  }
823 
824  if (N % 2) {
825  __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
826 
827  __half d_residual = __hmul(
828  reinterpret_cast<__half const &>(lhs),
829  b_residual_ptr[N - 1]);
830 
831  result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
832  }
833 
834  #else
835 
837  for (int i = 0; i < N; ++i) {
838  result[i] = lhs * rhs[i];
839  }
840  #endif
841 
842  return result;
843  }
844 
846  Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
847  Array<half_t, N> result;
848  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
849 
850  __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
851  __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
852  __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));
853 
855  for (int i = 0; i < N / 2; ++i) {
856  result_ptr[i] = __hmul2(lhs_ptr[i], rhs_pair);
857  }
858 
859  if (N % 2) {
860  __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
861 
862  __half d_residual = __hmul(
863  a_residual_ptr[N - 1],
864  reinterpret_cast<__half const &>(rhs));
865 
866  result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
867  }
868 
869  #else
870 
872  for (int i = 0; i < N; ++i) {
873  result[i] = lhs[i] * rhs;
874  }
875  #endif
876 
877  return result;
878  }
879 };
880 
881 template <int N>
882 struct divides<Array<half_t, N>> {
884  Array<half_t, N> operator()(Array<half_t, N> const & lhs, Array<half_t, N> const &rhs) const {
885  Array<half_t, N> result;
886  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
887 
888  __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
889  __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
890  __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
891 
893  for (int i = 0; i < N / 2; ++i) {
894  result_ptr[i] = __h2div(lhs_ptr[i], rhs_ptr[i]);
895  }
896 
897  if (N % 2) {
898  __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
899  __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
900 
901  __half d_residual = __hdiv(
902  a_residual_ptr[N - 1],
903  b_residual_ptr[N - 1]);
904 
905  result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
906  }
907 
908  #else
909 
911  for (int i = 0; i < N; ++i) {
912  result[i] = lhs[i] / rhs[i];
913  }
914  #endif
915 
916  return result;
917  }
918 
920  Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
921  Array<half_t, N> result;
922  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
923 
924  __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
925  __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
926  __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
927 
929  for (int i = 0; i < N / 2; ++i) {
930  result_ptr[i] = __h2div(lhs_pair, rhs_ptr[i]);
931  }
932 
933  if (N % 2) {
934  __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
935 
936  __half d_residual = __hdiv(
937  reinterpret_cast<__half const &>(lhs),
938  b_residual_ptr[N - 1]);
939 
940  result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
941  }
942 
943  #else
944 
946  for (int i = 0; i < N; ++i) {
947  result[i] = lhs / rhs[i];
948  }
949  #endif
950 
951  return result;
952  }
953 
955  Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
956  Array<half_t, N> result;
957  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
958 
959  __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
960  __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
961  __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));
962 
964  for (int i = 0; i < N / 2; ++i) {
965  result_ptr[i] = __h2div(lhs_ptr[i], rhs_pair);
966  }
967 
968  if (N % 2) {
969  __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
970 
971  __half d_residual = __hdiv(
972  a_residual_ptr[N - 1],
973  reinterpret_cast<__half const &>(rhs));
974 
975  result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
976  }
977 
978  #else
979 
981  for (int i = 0; i < N; ++i) {
982  result[i] = lhs[i] / rhs;
983  }
984  #endif
985 
986  return result;
987  }
988 };
989 
990 template <int N>
991 struct negate<Array<half_t, N>> {
993  Array<half_t, N> operator()(Array<half_t, N> const & lhs) const {
994  Array<half_t, N> result;
995  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
996 
997  __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
998  __half2 const *source_ptr = reinterpret_cast<__half2 const *>(&lhs);
999 
1001  for (int i = 0; i < N / 2; ++i) {
1002  result_ptr[i] = __hneg2(source_ptr[i]);
1003  }
1004 
1005  if (N % 2) {
1006  half_t x = lhs[N - 1];
1007  __half lhs_val = -reinterpret_cast<__half const &>(x);
1008  result[N - 1] = reinterpret_cast<half_t const &>(lhs_val);
1009  }
1010 
1011  #else
1012 
1014  for (int i = 0; i < N; ++i) {
1015  result[i] = -lhs[i];
1016  }
1017  #endif
1018 
1019  return result;
1020  }
1021 };
1022 
1024 template <int N>
1025 struct multiply_add<Array<half_t, N>, Array<half_t, N>, Array<half_t, N>> {
1026 
1028  Array<half_t, N> operator()(
1029  Array<half_t, N> const &a,
1030  Array<half_t, N> const &b,
1031  Array<half_t, N> const &c) const {
1032 
1033  Array<half_t, N> result;
1034  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
1035 
1036  __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
1037  __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
1038  __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
1039  __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);
1040 
1042  for (int i = 0; i < N / 2; ++i) {
1043  result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_ptr[i]);
1044  }
1045 
1046  if (N % 2) {
1047 
1048  __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
1049  __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
1050  __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);
1051 
1052  __half d_residual = __hfma(
1053  a_residual_ptr[N - 1],
1054  b_residual_ptr[N - 1],
1055  c_residual_ptr[N - 1]);
1056 
1057  result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
1058  }
1059 
1060  #else
1061 
1063 
1065  for (int i = 0; i < N; ++i) {
1066  result[i] = op(a[i], b[i], c[i]);
1067  }
1068  #endif
1069 
1070  return result;
1071  }
1072 
1074  Array<half_t, N> operator()(
1075  half_t const &a,
1076  Array<half_t, N> const &b,
1077  Array<half_t, N> const &c) const {
1078 
1079  Array<half_t, N> result;
1080  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
1081 
1082  __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
1083  __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a));
1084  __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
1085  __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);
1086 
1088  for (int i = 0; i < N / 2; ++i) {
1089  result_ptr[i] = __hfma2(a_pair, b_ptr[i], c_ptr[i]);
1090  }
1091 
1092  if (N % 2) {
1093 
1094  __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
1095  __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);
1096  __half d_residual = __hfma(
1097  reinterpret_cast<__half const &>(a),
1098  b_residual_ptr[N - 1],
1099  c_residual_ptr[N - 1]);
1100 
1101  result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
1102  }
1103 
1104  #else
1105 
1107 
1109  for (int i = 0; i < N; ++i) {
1110  result[i] = op(a, b[i], c[i]);
1111  }
1112  #endif
1113 
1114  return result;
1115  }
1116 
1118  Array<half_t, N> operator()(
1119  Array<half_t, N> const &a,
1120  half_t const &b,
1121  Array<half_t, N> const &c) const {
1122 
1123  Array<half_t, N> result;
1124  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
1125 
1126  __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
1127  __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
1128  __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b));
1129  __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c);
1130 
1132  for (int i = 0; i < N / 2; ++i) {
1133  result_ptr[i] = __hfma2(a_ptr[i], b_pair, c_ptr[i]);
1134  }
1135 
1136  if (N % 2) {
1137 
1138  __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
1139  __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c);
1140 
1141  __half d_residual = __hfma(
1142  a_residual_ptr[N - 1],
1143  reinterpret_cast<__half const &>(b),
1144  c_residual_ptr[N - 1]);
1145 
1146  result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
1147  }
1148 
1149  #else
1150 
1152 
1154  for (int i = 0; i < N; ++i) {
1155  result[i] = op(a[i], b, c[i]);
1156  }
1157  #endif
1158 
1159  return result;
1160  }
1161 
1163  Array<half_t, N> operator()(
1164  Array<half_t, N> const &a,
1165  Array<half_t, N> const &b,
1166  half_t const &c) const {
1167 
1168  Array<half_t, N> result;
1169  #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
1170 
1171  __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
1172  __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a);
1173  __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b);
1174  __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c));
1175 
1177  for (int i = 0; i < N / 2; ++i) {
1178  result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_pair);
1179  }
1180 
1181  if (N % 2) {
1182 
1183  __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a);
1184  __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b);
1185 
1186  __half d_residual = __hfma(
1187  a_residual_ptr[N - 1],
1188  b_residual_ptr[N - 1],
1189  reinterpret_cast<__half const &>(c));
1190 
1191  result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
1192  }
1193 
1194  #else
1195 
1197 
1199  for (int i = 0; i < N; ++i) {
1200  result[i] = op(a[i], b[i], c);
1201  }
1202  #endif
1203 
1204  return result;
1205  }
1206 };
1207 
1209 
1210 } // namespace cutlass
Fused multiply-add.
Definition: functional.h:92
CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const
Definition: functional.h:351
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:578
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE T operator()(T const &a, T const &b, T const &c) const
Definition: functional.h:103
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, half_t const &rhs) const
Definition: functional.h:955
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const
Definition: functional.h:383
CUTLASS_HOST_DEVICE float const & imag(cuFloatComplex const &z)
Returns the imaginary part of the complex number.
Definition: complex.h:72
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:678
CUTLASS_HOST_DEVICE T operator()(T lhs, T const &rhs) const
Definition: functional.h:48
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const
Definition: functional.h:269
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const
Definition: functional.h:323
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(half_t const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:611
Defines a class for using IEEE half-precision floating-point types in host or device code...
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(half_t const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:920
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:778
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const
Definition: functional.h:191
Definition: functional.h:298
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const
Definition: functional.h:255
Definition: functional.h:235
IEEE half-precision floating-point type.
Definition: half.h:126
CUTLASS_HOST_DEVICE T operator()(T lhs) const
Definition: functional.h:85
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(half_t const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:811
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const
Definition: functional.h:205
CUTLASS_HOST_DEVICE float const & real(cuFloatComplex const &z)
Returns the real part of the complex number.
Definition: complex.h:56
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, half_t const &rhs) const
Definition: functional.h:743
CUTLASS_HOST_DEVICE float operator()(float const &lhs, float const &rhs) const
Definition: functional.h:309
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const
Definition: functional.h:337
CUTLASS_HOST_DEVICE T const & imag() const
Accesses the imaginary part of the complex number.
Definition: complex.h:240
CUTLASS_HOST_DEVICE C operator()(A const &a, B const &b, C const &c) const
Definition: functional.h:94
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &a, Array< half_t, N > const &b, half_t const &c) const
Definition: functional.h:1163
Definition: functional.h:46
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const
Definition: functional.h:369
CUTLASS_HOST_DEVICE T operator()(T const &lhs, T const &rhs) const
Definition: functional.h:238
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const
Definition: functional.h:415
CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const
Definition: functional.h:397
CUTLASS_HOST_DEVICE T operator()(T lhs, T const &rhs) const
Definition: functional.h:66
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
CUTLASS_HOST_DEVICE float operator()(float const &lhs, float const &rhs) const
Definition: functional.h:246
Definition: functional.h:83
CUTLASS_HOST_DEVICE T operator()(T lhs, T const &rhs) const
Definition: functional.h:75
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &a, T const &scalar, Array< T, N > const &c) const
Definition: functional.h:541
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &a, Array< T, N > const &b, Array< T, N > const &c) const
Definition: functional.h:527
CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const
Definition: functional.h:443
Definition: functional.h:64
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const
Definition: functional.h:429
CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const
Definition: functional.h:489
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:884
Definition: functional.h:73
CUTLASS_HOST_DEVICE T const & real() const
Accesses the real part of the complex number.
Definition: complex.h:232
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, half_t const &rhs) const
Definition: functional.h:643
Definition: complex.h:92
CUTLASS_HOST_DEVICE T operator()(T const &lhs, T const &rhs) const
Definition: functional.h:301
static CUTLASS_HOST_DEVICE T scalar_op(T const &lhs, T const &rhs)
Definition: functional.h:318
CUTLASS_HOST_DEVICE complex< T > operator()(T const &a, complex< T > const &b, complex< T > const &c) const
Definition: functional.h:164
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(half_t const &a, Array< half_t, N > const &b, Array< half_t, N > const &c) const
Definition: functional.h:1074
CUTLASS_HOST_DEVICE complex< T > operator()(complex< T > const &a, T const &b, complex< T > const &c) const
Definition: functional.h:142
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs) const
Definition: functional.h:508
Fused multiply-add.
Definition: functional.h:101
CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const
Definition: functional.h:283
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const
Definition: functional.h:475
CUTLASS_HOST_DEVICE T operator()(T lhs, T const &rhs) const
Definition: functional.h:57
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &a, half_t const &b, Array< half_t, N > const &c) const
Definition: functional.h:1118
CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const
Definition: functional.h:219
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(half_t const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:711
CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &b, Array< T, N > const &c) const
Definition: functional.h:555
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const
Definition: functional.h:461
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs) const
Definition: functional.h:993
Definition: functional.h:55
CUTLASS_HOST_DEVICE complex< T > operator()(complex< T > const &a, complex< T > const &b, complex< T > const &c) const
Definition: functional.h:118
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, half_t const &rhs) const
Definition: functional.h:846
Basic include for CUTLASS.
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &a, Array< half_t, N > const &b, Array< half_t, N > const &c) const
Definition: functional.h:1028