matmul#
Matrix Multiply (GEMM)#
matmul performs a transformation for Generic Matrix Multiplies (GEMMs) for complex and real-valued tensors. Batching is supported for any tensor with a rank higher than 2.
-
template<typename OpA, typename OpB>
__MATX_INLINE__ auto matx::matmul(const OpA &A, const OpB &B, float alpha = 1.0, float beta = 0.0)# Run a GEMM (generic matrix multiply)
Creates a new GEMM plan in the cache if none exists, and uses that to execute the GEMM. This function is preferred over creating a plan directly for both efficiency and simpler code. Since it only uses the signature of the GEMM to decide if a plan is cached, it may be able to reuse plans for different A/B/C matrices as long as they were configured with the same dimensions.
- Template Parameters:
OpA – Data type of A tensor or operator
OpB – Data type of B tensor or operator
- Parameters:
A – A Tensor or Operator of shape `... x m x k`
B – B Tensor or Operator of shape `... x k x n`
alpha – Scalar multiplier to apply to operator A
beta – Scalar multiplier to apply to operator C on input
- Returns:
Operator that produces the output tensor C of shape `... x m x n`
-
template<typename OpA, typename OpB>
__MATX_INLINE__ auto matx::matmul(const OpA &A, const OpB &B, const int32_t (&axis)[2], float alpha = 1.0, float beta = 0.0)# Run a GEMM (generic matrix multiply)
Creates a new GEMM plan in the cache if none exists, and uses that to execute the GEMM. This function is preferred over creating a plan directly for both efficiency and simpler code. Since it only uses the signature of the GEMM to decide if a plan is cached, it may be able to reuse plans for different A/B/C matrices as long as they were configured with the same dimensions.
- Template Parameters:
OpA – Data type of A tensor or operator
OpB – Data type of B tensor or operator
- Parameters:
A – A Tensor or Operator of shape `... x m x k`
B – B Tensor or Operator of shape `... x k x n`
axis – The two axes of the tensor or operator to perform the GEMM along
alpha – Scalar multiplier to apply to operator A
beta – Scalar multiplier to apply to operator C on input
- Returns:
Operator that produces the output tensor C of shape `... x m x n`
Examples#
// Perform the GEMM C = A*B
(c = matmul(a, b)).run(this->exec);
Permuted A
// Perform the GEMM C = A^T * B
auto at = a.PermuteMatrix();
(c = matmul(at, b)).run(this->exec);
Permuted B
// Perform the GEMM C = A * B^T
auto bt = b.PermuteMatrix();
(c = matmul(a, bt)).run(this->exec);
Batched
constexpr index_t batches = 5;
constexpr index_t m = 128;
constexpr index_t k = 256;
constexpr index_t n = 512;
tensor_t<TestType, 3> a{{batches, m, k}};
tensor_t<TestType, 3> b{{batches, k, n}};
tensor_t<TestType, 3> c{{batches, m, n}};
this->pb->template InitAndRunTVGenerator<TestType>(
"00_transforms", "matmul_operators", "run", {batches, m, k, n});
this->pb->NumpyToTensorView(a, "a");
this->pb->NumpyToTensorView(b, "b");
// Perform a batched gemm with "batches" GEMMs
(c = matmul(a, b)).run(this->exec);
Strided Batched
constexpr index_t batches = 16;
constexpr index_t m = 128;
constexpr index_t k = 256;
constexpr index_t n = 512;
tensor_t<TestType, 3> a{{batches, m, k}};
tensor_t<TestType, 3> b{{batches, k, n}};
tensor_t<TestType, 3> c{{batches, m, n}};
auto as = slice(a, {0, 0, 0}, {matxEnd, matxEnd, matxEnd}, {2, 1, 1});
auto bs = slice(b, {0, 0, 0}, {matxEnd, matxEnd, matxEnd}, {2, 1, 1});
tensor_t<TestType, 3> cs{{batches/2, m, n}};
this->pb->template InitAndRunTVGenerator<TestType>(
"00_transforms", "matmul_operators", "run", {batches, m, k, n});
this->pb->NumpyToTensorView(a, "a");
this->pb->NumpyToTensorView(b, "b");
// Perform a strided and batched GEMM where "as" and "bs" have a stride of 2 in their outer-most (batch) dimension
(cs = matmul(as, bs)).run(this->exec);
const int axis[2] = {2, 1};
cuda::std::array<int, 3> perm({0, 2, 1});
auto ai = make_tensor<TestType>({b, k, m});
auto bi = make_tensor<TestType>({b, n, k});
auto ci = make_tensor<TestType>({b, n, m});
auto ap = permute(ai, perm);
auto bp = permute(bi, perm);
auto cp = permute(ci, perm);
// copy data into permuted inputs
(ap = a3).run(this->exec);
(bp = b3).run(this->exec);
// Perform a GEMM with the last two dimensions permuted
(ci = matmul(ai, bi, axis)).run(this->exec);