CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
inner_product.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/cutlass.h"
31 #include "cutlass/array.h"
32 
33 namespace cutlass {
34 namespace reference {
35 namespace detail {
36 
38 
/// Generic reference inner-product step.
///
/// Both operands are first converted to the accumulator type Ctype, then
/// multiplied, and the product is added to the running accumulator `c`.
/// The whole computation is kept in one expression on purpose, so the
/// compiler treats the multiply-add exactly as a single expression.
#pragma hd_warning_disable  // Suppresses warnings when instantiated with a host-only type
template <typename Atype, typename Btype, typename Ctype>
Ctype inner_product(Atype a, Btype b, Ctype c) {
  return static_cast<Ctype>(a) * static_cast<Ctype>(b) + c;
}
47 
49 template <>
51 int inner_product<Array<bin1_t, 32>, Array<bin1_t, 32>, int>(
52  Array<bin1_t, 32> a,
53  Array<bin1_t, 32> b,
54  int c) {
55 
56  int accum = 0;
57  for (int bit = 0; bit < 32; bit++) {
58  accum += a[bit] ^ b[bit];
59  }
60  return accum + c;
61 }
62 
63 /*
65 template <>
66 CUTLASS_HOST_DEVICE
67 int inner_product<Array<int4b_t, 8>, Array<int4b_t, 8>, int>(
68  Array<int4b_t, 8> a,
69  Array<int4b_t, 8> b,
70  int c) {
71 
72  int accum = 0;
73  for (int k = 0; k < 8; k++) {
74  accum += a[k] * b[k];
75  }
76  return accum + c;
77 }
78 
80 template <>
81 CUTLASS_HOST_DEVICE
82 int inner_product<Array<uint4b_t, 8>, Array<uint4b_t, 8>, int>(
83  Array<uint4b_t, 8> a,
84  Array<uint4b_t, 8> b,
85  int c) {
86 
87  int accum = 0;
88  for (int k = 0; k < 8; k++) {
89  accum += a[k] * b[k];
90  }
91  return accum + c;
92 }
93 */
94 
96 
97 template <typename SrcType, typename DstType>
98 struct Cast {
99  // Default behavior: convert to the destination type
100 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
101  // host-only type
103  static DstType apply(SrcType src) { return static_cast<DstType>(src); };
104 };
105 
106 template <>
107 struct Cast<float, int8_t> {
109  static int8_t apply(float src) {
110  // Clamp to the range of signed 8-bit integers.
111  return static_cast<int8_t>(fmaxf(-128.f, fminf(127.f, src)));
112  };
113 };
114 
115 template <>
116 struct Cast<float, uint8_t> {
118  static uint8_t apply(float src) {
119  // Clamp to the range of signed 8-bit integers.
120  return static_cast<uint8_t>(fmaxf(0.f, fminf(255.f, src)));
121  };
122 };
123 
125 
126 } // namespace detail
127 } // namespace reference
128 } // namespace cutlass
129 
Definition: aligned_buffer.h:35
static CUTLASS_HOST_DEVICE DstType apply(SrcType src)
Definition: inner_product.h:103
CUTLASS_HOST_DEVICE Ctype inner_product(Atype a, Btype b, Ctype c)
Template function to compute an inner product.
Definition: inner_product.h:44
Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe ...
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:89
static CUTLASS_HOST_DEVICE uint8_t apply(float src)
Definition: inner_product.h:118
Definition: inner_product.h:98
Basic include for CUTLASS.
static CUTLASS_HOST_DEVICE int8_t apply(float src)
Definition: inner_product.h:109