CUTLASS
CUDA Templates for Linear Algebra Subroutines and Solvers
tools/util/include/cutlass/util/reference/host/gemm_complex.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/coord.h"
32 #include "cutlass/complex.h"
33 #include "cutlass/numeric_types.h"
34 #include "cutlass/functional.h"
36 
37 #include "cutlass/matrix_traits.h"
38 #include "cutlass/tensor_view.h"
39 #include "cutlass/gemm/gemm.h"
40 
41 namespace cutlass {
42 namespace reference {
43 namespace host {
44 
46 
54 template <
55  typename ElementA,
56  typename LayoutA,
57  typename ElementB,
58  typename LayoutB,
59  typename ElementC,
60  typename LayoutC,
61  typename ScalarType,
62  typename ComputeType,
63  typename ConvertOp = NumericConverter<ElementC, ScalarType>,
64  typename InnerProductOp = multiply_add<ComputeType>
65 >
67  gemm::GemmCoord problem_size,
68  ScalarType alpha,
70  ComplexTransform transform_a,
72  ComplexTransform transform_b,
73  ScalarType beta,
75  ComputeType initial_accum) {
76 
78  LayoutA::kRank == 2 &&
79  LayoutB::kRank == 2 &&
80  LayoutC::kRank == 2, "Tensors must be of rank 2");
81 
82  // Note: batch is ignored.
83  int const M = problem_size.m();
84  int const N = problem_size.n();
85  int const K = problem_size.k();
86 
87  // Blocking necessary to speedup reference implementation
88  int const Mblock = 16;
89  int const Nblock = 16;
90 
91  ConvertOp convert_op;
92  InnerProductOp inner_product_op;
93 
94  for (int row_block = 0; row_block < M; row_block += Mblock) {
95  for (int col_block = 0; col_block < N; col_block += Nblock) {
96 
97  ComputeType accum[Mblock][Nblock];
98 
99  for (int j = 0; j < Nblock; j++) {
100  for (int i = 0; i < Mblock; i++) {
101  accum[i][j] = initial_accum;
102  }
103  }
104 
105  for (int k_block = 0; k_block < K; ++k_block) {
106  for (int j = 0; j < Nblock; j++) {
107  for (int i = 0; i < Mblock; i++) {
108  int row = row_block + i;
109  int col = col_block + j;
110 
111  if (row < M && col < N) {
112  ElementA a = tensor_a.at(MatrixCoord(row, k_block));
113  ElementB b = tensor_b.at(MatrixCoord(k_block, col));
114 
115  ComputeType a_ik = ComputeType(a);
116  ComputeType b_kj = ComputeType(b);
117 
118  if (transform_a == ComplexTransform::kConjugate) {
119  a_ik = conj(a_ik);
120  }
121 
122  if (transform_b == ComplexTransform::kConjugate) {
123  b_kj = conj(b_kj);
124  }
125 
126  accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]);
127  }
128  }
129  }
130  }
131 
132  for (int j = 0; j < Nblock; j++) {
133  for (int i = 0; i < Mblock; i++) {
134  int row = row_block + i;
135  int col = col_block + j;
136 
137  MatrixCoord coord = MatrixCoord(row, col);
138 
139  if (row < M && col < N) {
140 
141  tensor_c.at(coord) = convert_op(
142  alpha * ScalarType(accum[i][j]) +
143  beta * ScalarType(tensor_c.at(coord)));
144  }
145  }
146  }
147  }
148  }
149 }
150 
152 
157 template <
158  typename ElementA,
159  typename LayoutA,
160  typename ElementB,
161  typename LayoutB,
162  typename ElementC,
163  typename LayoutC,
164  typename ScalarType
165 >
167  gemm::GemmCoord problem_size,
168  ScalarType alpha,
170  ComplexTransform transform_a,
172  ComplexTransform transform_b,
173  ScalarType beta,
174  TensorRef<ElementC, LayoutC> tensor_c) {
175 
176  GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, ScalarType(0));
177 }
178 
180 
181 } // namespace host
182 } // namespace reference
183 } // namespace cutlass
Definition: aligned_buffer.h:35
ComplexTransform
Enumerated type describing a transformation on a complex value.
Definition: complex.h:43
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Definition: include/cutlass/gemm/gemm.h:94
Defines common types used for all GEMM-like operators.
CUTLASS_HOST_DEVICE Index const & n() const
Returns the GEMM N coordinate.
Definition: include/cutlass/gemm/gemm.h:137
Defines a structure containing strides and a pointer to tensor data.
CUTLASS_HOST_DEVICE Index const & k() const
Returns the GEMM K coordinate.
Definition: include/cutlass/gemm/gemm.h:145
void GemmComplex(gemm::GemmCoord problem_size, ScalarType alpha, TensorRef< ElementA, LayoutA > tensor_a, ComplexTransform transform_a, TensorRef< ElementB, LayoutB > tensor_b, ComplexTransform transform_b, ScalarType beta, TensorRef< ElementC, LayoutC > tensor_c, ComputeType initial_accum)
Definition: tools/util/include/cutlass/util/reference/host/gemm_complex.h:66
Boost-like numeric conversion operator for CUTLASS numeric types.
CUTLASS_HOST_DEVICE complex< T > conj(complex< T > const &z)
Returns the complex conjugate.
Definition: complex.h:356
Top-level include for all CUTLASS numeric types.
#define static_assert(__e, __m)
Definition: platform.h:153
CUTLASS_HOST_DEVICE Reference at(TensorCoord const &coord) const
Returns a reference to the element at a given Coord.
Definition: tensor_ref.h:307
CUTLASS_HOST_DEVICE Index const & m() const
Returns the GEMM M coordinate.
Definition: include/cutlass/gemm/gemm.h:129
Defines properties of matrices used to denote layout and operands to GEMM kernels.
Definition: matrix_coord.h:39
Define basic numeric operators with specializations for Array<T, N>. SIMD-ize where possible...