91 template <
typename A,
typename B = A,
typename C = A>
95 return C(a) * C(b) + c;
100 template <
typename T>
104 return ((a ^ b) + c);
115 template <
typename T>
139 template <
typename T>
150 real += a.
real() * b;
151 imag += a.
imag () * b;
161 template <
typename T>
172 real += a * b.
real();
173 imag += a * b.
imag();
188 template <
typename T,
int N>
191 Array<T, N>
operator()(Array<T, N>
const &lhs, Array<T, N>
const &rhs)
const {
197 for (
int i = 0; i < N; ++i) {
198 result[i] = scalar_op(lhs[i], rhs[i]);
205 Array<T, N>
operator()(Array<T, N>
const &lhs, T
const &scalar)
const {
211 for (
int i = 0; i < N; ++i) {
212 result[i] = scalar_op(lhs[i], scalar);
219 Array<T, N>
operator()( T
const &scalar, Array<T, N>
const &rhs)
const {
225 for (
int i = 0; i < N; ++i) {
226 result[i] = scalar_op(scalar, rhs[i]);
234 template <
typename T>
239 return (lhs < rhs ? rhs : lhs);
247 return fmaxf(lhs, rhs);
251 template <
typename T,
int N>
255 Array<T, N>
operator()(Array<T, N>
const &lhs, Array<T, N>
const &rhs)
const {
261 for (
int i = 0; i < N; ++i) {
262 result[i] = scalar_op(lhs[i], rhs[i]);
269 Array<T, N>
operator()(Array<T, N>
const &lhs, T
const &scalar)
const {
275 for (
int i = 0; i < N; ++i) {
276 result[i] = scalar_op(lhs[i], scalar);
283 Array<T, N>
operator()( T
const &scalar, Array<T, N>
const &rhs)
const {
289 for (
int i = 0; i < N; ++i) {
290 result[i] = scalar_op(scalar, rhs[i]);
297 template <
typename T>
302 return (rhs < lhs ? rhs : lhs);
310 return fminf(lhs, rhs);
314 template <
typename T,
int N>
319 return (rhs < lhs ? rhs : lhs);
323 Array<T, N>
operator()(Array<T, N>
const &lhs, Array<T, N>
const &rhs)
const {
329 for (
int i = 0; i < N; ++i) {
330 result[i] = scalar_op(lhs[i], rhs[i]);
337 Array<T, N>
operator()(Array<T, N>
const &lhs, T
const &scalar)
const {
343 for (
int i = 0; i < N; ++i) {
344 result[i] = scalar_op(lhs[i], scalar);
351 Array<T, N>
operator()( T
const &scalar, Array<T, N>
const &rhs)
const {
357 for (
int i = 0; i < N; ++i) {
358 result[i] = scalar_op(scalar, rhs[i]);
365 template <
typename T,
int N>
369 Array<T, N>
operator()(Array<T, N>
const &lhs, Array<T, N>
const &rhs)
const {
375 for (
int i = 0; i < N; ++i) {
376 result[i] = scalar_op(lhs[i], rhs[i]);
383 Array<T, N>
operator()(Array<T, N>
const &lhs, T
const &scalar)
const {
389 for (
int i = 0; i < N; ++i) {
390 result[i] = scalar_op(lhs[i], scalar);
397 Array<T, N>
operator()( T
const &scalar, Array<T, N>
const &rhs)
const {
403 for (
int i = 0; i < N; ++i) {
404 result[i] = scalar_op(scalar, rhs[i]);
411 template <
typename T,
int N>
415 Array<T, N>
operator()(Array<T, N>
const &lhs, Array<T, N>
const &rhs)
const {
421 for (
int i = 0; i < N; ++i) {
422 result[i] = scalar_op(lhs[i], rhs[i]);
429 Array<T, N>
operator()(Array<T, N>
const &lhs, T
const &scalar)
const {
435 for (
int i = 0; i < N; ++i) {
436 result[i] = scalar_op(lhs[i], scalar);
443 Array<T, N>
operator()( T
const &scalar, Array<T, N>
const &rhs)
const {
449 for (
int i = 0; i < N; ++i) {
450 result[i] = scalar_op(scalar, rhs[i]);
457 template <
typename T,
int N>
461 Array<T, N>
operator()(Array<T, N>
const &lhs, Array<T, N>
const &rhs)
const {
467 for (
int i = 0; i < N; ++i) {
468 result[i] = scalar_op(lhs[i], rhs[i]);
475 Array<T, N>
operator()(Array<T, N>
const &lhs, T
const &scalar)
const {
481 for (
int i = 0; i < N; ++i) {
482 result[i] = scalar_op(lhs[i], scalar);
489 Array<T, N>
operator()( T
const &scalar, Array<T, N>
const &rhs)
const {
495 for (
int i = 0; i < N; ++i) {
496 result[i] = scalar_op(scalar, rhs[i]);
504 template <
typename T,
int N>
514 for (
int i = 0; i < N; ++i) {
515 result[i] = scalar_op(lhs[i]);
523 template <
typename T,
int N>
527 Array<T, N>
operator()(Array<T, N>
const &a, Array<T, N>
const &b, Array<T, N>
const &c)
const {
533 for (
int i = 0; i < N; ++i) {
534 result[i] = scalar_op(a[i], b[i], c[i]);
541 Array<T, N>
operator()(Array<T, N>
const &a, T
const &scalar, Array<T, N>
const &c)
const {
547 for (
int i = 0; i < N; ++i) {
548 result[i] = scalar_op(a[i], scalar, c[i]);
555 Array<T, N>
operator()(T
const &scalar, Array<T, N>
const &b, Array<T, N>
const &c)
const {
561 for (
int i = 0; i < N; ++i) {
562 result[i] = scalar_op(scalar, b[i], c[i]);
578 Array<half_t, N>
operator()(Array<half_t, N>
const & lhs, Array<half_t, N>
const &rhs)
const {
579 Array<half_t, N> result;
580 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 582 __half2 *result_ptr =
reinterpret_cast<__half2 *
>(&result);
583 __half2
const *lhs_ptr =
reinterpret_cast<__half2
const *
>(&lhs);
584 __half2
const *rhs_ptr =
reinterpret_cast<__half2
const *
>(&rhs);
587 for (
int i = 0; i < N / 2; ++i) {
588 result_ptr[i] = __hadd2(lhs_ptr[i], rhs_ptr[i]);
592 __half
const *a_residual_ptr =
reinterpret_cast<__half
const *
>(&lhs);
593 __half
const *b_residual_ptr =
reinterpret_cast<__half
const *
>(&rhs);
594 __half d_residual = __hadd(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);
596 result[N - 1] =
reinterpret_cast<half_t const &
>(d_residual);
602 for (
int i = 0; i < N; ++i) {
603 result[i] = lhs[i] + rhs[i];
612 Array<half_t, N> result;
613 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 615 __half2 *result_ptr =
reinterpret_cast<__half2 *
>(&result);
616 __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
617 __half2
const *rhs_ptr =
reinterpret_cast<__half2
const *
>(&rhs);
620 for (
int i = 0; i < N / 2; ++i) {
621 result_ptr[i] = __hadd2(lhs_pair, rhs_ptr[i]);
625 __half
const *b_residual_ptr =
reinterpret_cast<__half
const *
>(&rhs);
626 __half d_residual = __hadd(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]);
628 result[N - 1] =
reinterpret_cast<half_t const &
>(d_residual);
634 for (
int i = 0; i < N; ++i) {
635 result[i] = lhs + rhs[i];
644 Array<half_t, N> result;
645 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 647 __half2 *result_ptr =
reinterpret_cast<__half2 *
>(&result);
648 __half2
const *lhs_ptr =
reinterpret_cast<__half2
const *
>(&lhs);
649 __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));
652 for (
int i = 0; i < N / 2; ++i) {
653 result_ptr[i] = __hadd2(lhs_ptr[i], rhs_pair);
657 __half
const *a_residual_ptr =
reinterpret_cast<__half
const *
>(&lhs);
658 __half d_residual = __hadd(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs));
660 result[N - 1] =
reinterpret_cast<half_t const &
>(d_residual);
666 for (
int i = 0; i < N; ++i) {
667 result[i] = lhs[i] + rhs;
678 Array<half_t, N>
operator()(Array<half_t, N>
const & lhs, Array<half_t, N>
const &rhs)
const {
679 Array<half_t, N> result;
680 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 682 __half2 *result_ptr =
reinterpret_cast<__half2 *
>(&result);
683 __half2
const *lhs_ptr =
reinterpret_cast<__half2
const *
>(&lhs);
684 __half2
const *rhs_ptr =
reinterpret_cast<__half2
const *
>(&rhs);
687 for (
int i = 0; i < N / 2; ++i) {
688 result_ptr[i] = __hsub2(lhs_ptr[i], rhs_ptr[i]);
692 __half
const *a_residual_ptr =
reinterpret_cast<__half
const *
>(&lhs);
693 __half
const *b_residual_ptr =
reinterpret_cast<__half
const *
>(&rhs);
694 __half d_residual = __hsub(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);
696 result[N - 1] =
reinterpret_cast<half_t const &
>(d_residual);
702 for (
int i = 0; i < N; ++i) {
703 result[i] = lhs[i] - rhs[i];
712 Array<half_t, N> result;
713 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 715 __half2 *result_ptr =
reinterpret_cast<__half2 *
>(&result);
716 __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
717 __half2
const *rhs_ptr =
reinterpret_cast<__half2
const *
>(&rhs);
720 for (
int i = 0; i < N / 2; ++i) {
721 result_ptr[i] = __hsub2(lhs_pair, rhs_ptr[i]);
725 __half
const *b_residual_ptr =
reinterpret_cast<__half
const *
>(&rhs);
726 __half d_residual = __hsub(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]);
728 result[N - 1] =
reinterpret_cast<half_t const &
>(d_residual);
734 for (
int i = 0; i < N; ++i) {
735 result[i] = lhs - rhs[i];
744 Array<half_t, N> result;
745 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 747 __half2 *result_ptr =
reinterpret_cast<__half2 *
>(&result);
748 __half2
const *lhs_ptr =
reinterpret_cast<__half2
const *
>(&lhs);
749 __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));
752 for (
int i = 0; i < N / 2; ++i) {
753 result_ptr[i] = __hsub2(lhs_ptr[i], rhs_pair);
757 __half
const *a_residual_ptr =
reinterpret_cast<__half
const *
>(&lhs);
758 __half d_residual = __hsub(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs));
760 result[N - 1] =
reinterpret_cast<half_t const &
>(d_residual);
766 for (
int i = 0; i < N; ++i) {
767 result[i] = lhs[i] - rhs;
778 Array<half_t, N>
operator()(Array<half_t, N>
const & lhs, Array<half_t, N>
const &rhs)
const {
779 Array<half_t, N> result;
780 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 782 __half2 *result_ptr =
reinterpret_cast<__half2 *
>(&result);
783 __half2
const *lhs_ptr =
reinterpret_cast<__half2
const *
>(&lhs);
784 __half2
const *rhs_ptr =
reinterpret_cast<__half2
const *
>(&rhs);
787 for (
int i = 0; i < N / 2; ++i) {
788 result_ptr[i] = __hmul2(lhs_ptr[i], rhs_ptr[i]);
792 __half
const *a_residual_ptr =
reinterpret_cast<__half
const *
>(&lhs);
793 __half
const *b_residual_ptr =
reinterpret_cast<__half
const *
>(&rhs);
794 __half d_residual = __hmul(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);
796 result[N - 1] =
reinterpret_cast<half_t const &
>(d_residual);
802 for (
int i = 0; i < N; ++i) {
803 result[i] = lhs[i] * rhs[i];
812 Array<half_t, N> result;
813 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 815 __half2 *result_ptr =
reinterpret_cast<__half2 *
>(&result);
816 __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
817 __half2
const *rhs_ptr =
reinterpret_cast<__half2
const *
>(&rhs);
820 for (
int i = 0; i < N / 2; ++i) {
821 result_ptr[i] = __hmul2(lhs_pair, rhs_ptr[i]);
825 __half
const *b_residual_ptr =
reinterpret_cast<__half
const *
>(&rhs);
827 __half d_residual = __hmul(
828 reinterpret_cast<__half const &>(lhs),
829 b_residual_ptr[N - 1]);
831 result[N - 1] =
reinterpret_cast<half_t const &
>(d_residual);
837 for (
int i = 0; i < N; ++i) {
838 result[i] = lhs * rhs[i];
847 Array<half_t, N> result;
848 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 850 __half2 *result_ptr =
reinterpret_cast<__half2 *
>(&result);
851 __half2
const *lhs_ptr =
reinterpret_cast<__half2
const *
>(&lhs);
852 __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));
855 for (
int i = 0; i < N / 2; ++i) {
856 result_ptr[i] = __hmul2(lhs_ptr[i], rhs_pair);
860 __half
const *a_residual_ptr =
reinterpret_cast<__half
const *
>(&lhs);
862 __half d_residual = __hmul(
863 a_residual_ptr[N - 1],
864 reinterpret_cast<__half const &>(rhs));
866 result[N - 1] =
reinterpret_cast<half_t const &
>(d_residual);
872 for (
int i = 0; i < N; ++i) {
873 result[i] = lhs[i] * rhs;
884 Array<half_t, N>
operator()(Array<half_t, N>
const & lhs, Array<half_t, N>
const &rhs)
const {
885 Array<half_t, N> result;
886 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 888 __half2 *result_ptr =
reinterpret_cast<__half2 *
>(&result);
889 __half2
const *lhs_ptr =
reinterpret_cast<__half2
const *
>(&lhs);
890 __half2
const *rhs_ptr =
reinterpret_cast<__half2
const *
>(&rhs);
893 for (
int i = 0; i < N / 2; ++i) {
894 result_ptr[i] = __h2div(lhs_ptr[i], rhs_ptr[i]);
898 __half
const *a_residual_ptr =
reinterpret_cast<__half
const *
>(&lhs);
899 __half
const *b_residual_ptr =
reinterpret_cast<__half
const *
>(&rhs);
901 __half d_residual = __hdiv(
902 a_residual_ptr[N - 1],
903 b_residual_ptr[N - 1]);
905 result[N - 1] =
reinterpret_cast<half_t const &
>(d_residual);
911 for (
int i = 0; i < N; ++i) {
912 result[i] = lhs[i] / rhs[i];
921 Array<half_t, N> result;
922 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 924 __half2 *result_ptr =
reinterpret_cast<__half2 *
>(&result);
925 __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
926 __half2
const *rhs_ptr =
reinterpret_cast<__half2
const *
>(&rhs);
929 for (
int i = 0; i < N / 2; ++i) {
930 result_ptr[i] = __h2div(lhs_pair, rhs_ptr[i]);
934 __half
const *b_residual_ptr =
reinterpret_cast<__half
const *
>(&rhs);
936 __half d_residual = __hdiv(
937 reinterpret_cast<__half const &>(lhs),
938 b_residual_ptr[N - 1]);
940 result[N - 1] =
reinterpret_cast<half_t const &
>(d_residual);
946 for (
int i = 0; i < N; ++i) {
947 result[i] = lhs / rhs[i];
956 Array<half_t, N> result;
957 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 959 __half2 *result_ptr =
reinterpret_cast<__half2 *
>(&result);
960 __half2
const *lhs_ptr =
reinterpret_cast<__half2
const *
>(&lhs);
961 __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));
964 for (
int i = 0; i < N / 2; ++i) {
965 result_ptr[i] = __h2div(lhs_ptr[i], rhs_pair);
969 __half
const *a_residual_ptr =
reinterpret_cast<__half
const *
>(&lhs);
971 __half d_residual = __hdiv(
972 a_residual_ptr[N - 1],
973 reinterpret_cast<__half const &>(rhs));
975 result[N - 1] =
reinterpret_cast<half_t const &
>(d_residual);
981 for (
int i = 0; i < N; ++i) {
982 result[i] = lhs[i] / rhs;
993 Array<half_t, N>
operator()(Array<half_t, N>
const & lhs)
const {
994 Array<half_t, N> result;
995 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 997 __half2 *result_ptr =
reinterpret_cast<__half2 *
>(&result);
998 __half2
const *source_ptr =
reinterpret_cast<__half2
const *
>(&lhs);
1001 for (
int i = 0; i < N / 2; ++i) {
1002 result_ptr[i] = __hneg2(source_ptr[i]);
1007 __half lhs_val = -
reinterpret_cast<__half
const &
>(x);
1008 result[N - 1] =
reinterpret_cast<half_t const &
>(lhs_val);
1014 for (
int i = 0; i < N; ++i) {
1015 result[i] = -lhs[i];
1029 Array<half_t, N>
const &a,
1030 Array<half_t, N>
const &b,
1031 Array<half_t, N>
const &c)
const {
1033 Array<half_t, N> result;
1034 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 1036 __half2 *result_ptr =
reinterpret_cast<__half2 *
>(&result);
1037 __half2
const *a_ptr =
reinterpret_cast<__half2
const *
>(&a);
1038 __half2
const *b_ptr =
reinterpret_cast<__half2
const *
>(&b);
1039 __half2
const *c_ptr =
reinterpret_cast<__half2
const *
>(&c);
1042 for (
int i = 0; i < N / 2; ++i) {
1043 result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_ptr[i]);
1048 __half
const *a_residual_ptr =
reinterpret_cast<__half
const *
>(&a);
1049 __half
const *b_residual_ptr =
reinterpret_cast<__half
const *
>(&b);
1050 __half
const *c_residual_ptr =
reinterpret_cast<__half
const *
>(&c);
1052 __half d_residual = __hfma(
1053 a_residual_ptr[N - 1],
1054 b_residual_ptr[N - 1],
1055 c_residual_ptr[N - 1]);
1057 result[N - 1] =
reinterpret_cast<half_t const &
>(d_residual);
1065 for (
int i = 0; i < N; ++i) {
1066 result[i] = op(a[i], b[i], c[i]);
1076 Array<half_t, N>
const &b,
1077 Array<half_t, N>
const &c)
const {
1079 Array<half_t, N> result;
1080 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 1082 __half2 *result_ptr =
reinterpret_cast<__half2 *
>(&result);
1083 __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a));
1084 __half2
const *b_ptr =
reinterpret_cast<__half2
const *
>(&b);
1085 __half2
const *c_ptr =
reinterpret_cast<__half2
const *
>(&c);
1088 for (
int i = 0; i < N / 2; ++i) {
1089 result_ptr[i] = __hfma2(a_pair, b_ptr[i], c_ptr[i]);
1094 __half
const *b_residual_ptr =
reinterpret_cast<__half
const *
>(&b);
1095 __half
const *c_residual_ptr =
reinterpret_cast<__half
const *
>(&c);
1096 __half d_residual = __hfma(
1097 reinterpret_cast<__half const &>(a),
1098 b_residual_ptr[N - 1],
1099 c_residual_ptr[N - 1]);
1101 result[N - 1] =
reinterpret_cast<half_t const &
>(d_residual);
1109 for (
int i = 0; i < N; ++i) {
1110 result[i] = op(a, b[i], c[i]);
1119 Array<half_t, N>
const &a,
1121 Array<half_t, N>
const &c)
const {
1123 Array<half_t, N> result;
1124 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 1126 __half2 *result_ptr =
reinterpret_cast<__half2 *
>(&result);
1127 __half2
const *a_ptr =
reinterpret_cast<__half2
const *
>(&a);
1128 __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b));
1129 __half2
const *c_ptr =
reinterpret_cast<__half2
const *
>(&c);
1132 for (
int i = 0; i < N / 2; ++i) {
1133 result_ptr[i] = __hfma2(a_ptr[i], b_pair, c_ptr[i]);
1138 __half
const *a_residual_ptr =
reinterpret_cast<__half
const *
>(&a);
1139 __half
const *c_residual_ptr =
reinterpret_cast<__half
const *
>(&c);
1141 __half d_residual = __hfma(
1142 a_residual_ptr[N - 1],
1143 reinterpret_cast<__half const &>(b),
1144 c_residual_ptr[N - 1]);
1146 result[N - 1] =
reinterpret_cast<half_t const &
>(d_residual);
1154 for (
int i = 0; i < N; ++i) {
1155 result[i] = op(a[i], b, c[i]);
1164 Array<half_t, N>
const &a,
1165 Array<half_t, N>
const &b,
1168 Array<half_t, N> result;
1169 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) 1171 __half2 *result_ptr =
reinterpret_cast<__half2 *
>(&result);
1172 __half2
const *a_ptr =
reinterpret_cast<__half2
const *
>(&a);
1173 __half2
const *b_ptr =
reinterpret_cast<__half2
const *
>(&b);
1174 __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c));
1177 for (
int i = 0; i < N / 2; ++i) {
1178 result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_pair);
1183 __half
const *a_residual_ptr =
reinterpret_cast<__half
const *
>(&a);
1184 __half
const *b_residual_ptr =
reinterpret_cast<__half
const *
>(&b);
1186 __half d_residual = __hfma(
1187 a_residual_ptr[N - 1],
1188 b_residual_ptr[N - 1],
1189 reinterpret_cast<__half const &>(c));
1191 result[N - 1] =
reinterpret_cast<half_t const &
>(d_residual);
1199 for (
int i = 0; i < N; ++i) {
1200 result[i] = op(a[i], b[i], c);
Fused multiply-add.
Definition: functional.h:92
CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const
Definition: functional.h:351
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:578
Definition: aligned_buffer.h:35
CUTLASS_HOST_DEVICE T operator()(T const &a, T const &b, T const &c) const
Definition: functional.h:103
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, half_t const &rhs) const
Definition: functional.h:955
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const
Definition: functional.h:383
CUTLASS_HOST_DEVICE float const & imag(cuFloatComplex const &z)
Returns the imaginary part of the complex number.
Definition: complex.h:72
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:678
CUTLASS_HOST_DEVICE T operator()(T lhs, T const &rhs) const
Definition: functional.h:48
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const
Definition: functional.h:269
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const
Definition: functional.h:323
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(half_t const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:611
Defines a class for using IEEE half-precision floating-point types in host or device code...
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(half_t const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:920
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:778
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const
Definition: functional.h:191
Definition: functional.h:298
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const
Definition: functional.h:255
Definition: functional.h:235
IEEE half-precision floating-point type.
Definition: half.h:126
CUTLASS_HOST_DEVICE T operator()(T lhs) const
Definition: functional.h:85
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(half_t const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:811
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const
Definition: functional.h:205
CUTLASS_HOST_DEVICE float const & real(cuFloatComplex const &z)
Returns the real part of the complex number.
Definition: complex.h:56
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, half_t const &rhs) const
Definition: functional.h:743
CUTLASS_HOST_DEVICE float operator()(float const &lhs, float const &rhs) const
Definition: functional.h:309
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const
Definition: functional.h:337
CUTLASS_HOST_DEVICE T const & imag() const
Accesses the imaginary part of the complex number.
Definition: complex.h:240
CUTLASS_HOST_DEVICE C operator()(A const &a, B const &b, C const &c) const
Definition: functional.h:94
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &a, Array< half_t, N > const &b, half_t const &c) const
Definition: functional.h:1163
Definition: functional.h:46
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const
Definition: functional.h:369
CUTLASS_HOST_DEVICE T operator()(T const &lhs, T const &rhs) const
Definition: functional.h:238
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const
Definition: functional.h:415
CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const
Definition: functional.h:397
CUTLASS_HOST_DEVICE T operator()(T lhs, T const &rhs) const
Definition: functional.h:66
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
CUTLASS_HOST_DEVICE float operator()(float const &lhs, float const &rhs) const
Definition: functional.h:246
Definition: functional.h:83
CUTLASS_HOST_DEVICE T operator()(T lhs, T const &rhs) const
Definition: functional.h:75
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &a, T const &scalar, Array< T, N > const &c) const
Definition: functional.h:541
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &a, Array< T, N > const &b, Array< T, N > const &c) const
Definition: functional.h:527
CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const
Definition: functional.h:443
Definition: functional.h:64
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
Top-level include for all CUTLASS numeric types.
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const
Definition: functional.h:429
CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const
Definition: functional.h:489
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:884
Definition: functional.h:73
CUTLASS_HOST_DEVICE T const & real() const
Accesses the real part of the complex number.
Definition: complex.h:232
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, half_t const &rhs) const
Definition: functional.h:643
CUTLASS_HOST_DEVICE T operator()(T const &lhs, T const &rhs) const
Definition: functional.h:301
static CUTLASS_HOST_DEVICE T scalar_op(T const &lhs, T const &rhs)
Definition: functional.h:318
CUTLASS_HOST_DEVICE complex< T > operator()(T const &a, complex< T > const &b, complex< T > const &c) const
Definition: functional.h:164
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(half_t const &a, Array< half_t, N > const &b, Array< half_t, N > const &c) const
Definition: functional.h:1074
CUTLASS_HOST_DEVICE complex< T > operator()(complex< T > const &a, T const &b, complex< T > const &c) const
Definition: functional.h:142
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs) const
Definition: functional.h:508
Fused multiply-add.
Definition: functional.h:101
CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const
Definition: functional.h:283
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, T const &scalar) const
Definition: functional.h:475
CUTLASS_HOST_DEVICE T operator()(T lhs, T const &rhs) const
Definition: functional.h:57
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &a, half_t const &b, Array< half_t, N > const &c) const
Definition: functional.h:1118
CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &rhs) const
Definition: functional.h:219
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(half_t const &lhs, Array< half_t, N > const &rhs) const
Definition: functional.h:711
CUTLASS_HOST_DEVICE Array< T, N > operator()(T const &scalar, Array< T, N > const &b, Array< T, N > const &c) const
Definition: functional.h:555
CUTLASS_HOST_DEVICE Array< T, N > operator()(Array< T, N > const &lhs, Array< T, N > const &rhs) const
Definition: functional.h:461
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs) const
Definition: functional.h:993
Definition: functional.h:55
CUTLASS_HOST_DEVICE complex< T > operator()(complex< T > const &a, complex< T > const &b, complex< T > const &c) const
Definition: functional.h:118
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &lhs, half_t const &rhs) const
Definition: functional.h:846
Basic include for CUTLASS.
CUTLASS_HOST_DEVICE Array< half_t, N > operator()(Array< half_t, N > const &a, Array< half_t, N > const &b, Array< half_t, N > const &c) const
Definition: functional.h:1028