matmul

Matrix Multiply (GEMM)

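matmul performs generic matrix multiplies (GEMMs) on complex and real-valued tensors. Batching is supported for any tensor with a rank higher than 2. By default, the last two dimensions hold each matrix, and the leading dimensions are treated as batch dimensions with one GEMM per batch.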
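For reference, the operation matmul computes can be written as a naive triple loop. The sketch below is plain C++, not MatX code. The function name gemm_ref and the dense row-major layout are assumptions made for illustration. The function is templated on the element type, so the same loop covers real (float) and complex (std::complex<float>) data.

```cpp
#include <cassert>
#include <complex>
#include <cstddef>
#include <vector>

// Hypothetical reference GEMM (not a MatX function): C = A * B.
// All matrices are dense and row-major: A is m x k, B is k x n, C is m x n.
template <typename T>
std::vector<T> gemm_ref(const std::vector<T>& A, const std::vector<T>& B,
                        std::size_t m, std::size_t k, std::size_t n) {
  std::vector<T> C(m * n, T{});
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      T acc{};  // dot product of row i of A with column j of B
      for (std::size_t p = 0; p < k; ++p) {
        acc += A[i * k + p] * B[p * n + j];
      }
      C[i * n + j] = acc;
    }
  }
  return C;
}
```

A GPU library computes the same result with tiled, parallel kernels. The loop only pins down the semantics.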

Note

This function is currently not supported with host-based executors (CPU).

template<typename OpA, typename OpB>
__MATX_INLINE__ auto matx::matmul(const OpA A, const OpB B, float alpha = 1.0, float beta = 0.0)

Run a GEMM (generic matrix multiply).

Creates a new GEMM plan in the cache if none exists, and uses that plan to execute the GEMM. This function is preferred over creating a plan directly, for both efficiency and simpler code. Because only the signature of the GEMM is used to decide whether a plan is cached, a cached plan may be reused for different A/B/C matrices as long as they were configured with the same dimensions.

Template Parameters:
  • OpA – Data type of A tensor or operator

  • OpB – Data type of B tensor or operator

Parameters:
  • A – Input tensor or operator A

  • B – Input tensor or operator B

  • alpha – Scalar multiplier to apply to operator A

  • beta – Scalar multiplier to apply to operator C on input
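The alpha and beta parameters appear to follow the conventional BLAS-style update C = alpha * (A * B) + beta * C. Under that reading, the defaults (alpha = 1, beta = 0) give plain C = A * B. The sketch below shows that update in plain C++, assuming dense row-major storage. gemm_alpha_beta is a hypothetical name, not a MatX function.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch (not MatX API) of the BLAS-style update
//   C = alpha * (A * B) + beta * C
// on dense row-major matrices (A: m x k, B: k x n, C: m x n).
// C is read as an input (scaled by beta) and then overwritten.
template <typename T>
void gemm_alpha_beta(const std::vector<T>& A, const std::vector<T>& B,
                     std::vector<T>& C, std::size_t m, std::size_t k,
                     std::size_t n, T alpha, T beta) {
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      T acc{};
      for (std::size_t p = 0; p < k; ++p) {
        acc += A[i * k + p] * B[p * n + j];
      }
      C[i * n + j] = alpha * acc + beta * C[i * n + j];
    }
  }
}
```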

template<typename OpA, typename OpB>
__MATX_INLINE__ auto matx::matmul(const OpA A, const OpB B, const int32_t (&axis)[2], float alpha = 1.0, float beta = 0.0)

Run a GEMM (generic matrix multiply).

Creates a new GEMM plan in the cache if none exists, and uses that plan to execute the GEMM. This function is preferred over creating a plan directly, for both efficiency and simpler code. Because only the signature of the GEMM is used to decide whether a plan is cached, a cached plan may be reused for different A/B/C matrices as long as they were configured with the same dimensions.

Template Parameters:
  • OpA – Data type of A tensor or operator

  • OpB – Data type of B tensor or operator

Parameters:
  • A – Input tensor or operator A

  • B – Input tensor or operator B

  • axis – The two axes of the tensors or operators along which to perform the GEMM

  • alpha – Scalar multiplier to apply to operator A

  • beta – Scalar multiplier to apply to operator C on input
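The plan caching described for both overloads can be pictured as a map from a GEMM signature to a plan object. This is a hypothetical sketch. GemmSignature, GemmPlan, and PlanCache are illustrative names, not MatX types. A real implementation may key on more fields (such as strides) than the element type, dimensions, and batch count used here.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <tuple>
#include <typeindex>
#include <typeinfo>

// Hypothetical signature: element type plus problem dimensions.
struct GemmSignature {
  std::type_index dtype;
  std::size_t m, k, n, batches;
  bool operator<(const GemmSignature& o) const {
    return std::tie(dtype, m, k, n, batches) <
           std::tie(o.dtype, o.m, o.k, o.n, o.batches);
  }
};

// Hypothetical plan; a real one would also hold library handles and workspace.
struct GemmPlan {
  GemmSignature sig;
};

class PlanCache {
 public:
  // Return the cached plan for this signature, creating it on first use.
  std::shared_ptr<GemmPlan> get(const GemmSignature& sig) {
    auto it = cache_.find(sig);
    if (it != cache_.end()) return it->second;
    ++plans_created_;
    auto plan = std::make_shared<GemmPlan>(GemmPlan{sig});
    cache_.emplace(sig, plan);
    return plan;
  }
  std::size_t plans_created() const { return plans_created_; }

 private:
  std::map<GemmSignature, std::shared_ptr<GemmPlan>> cache_;
  std::size_t plans_created_ = 0;
};
```

Two calls on different matrices with identical signatures return the same plan. A change in element type or dimensions creates a new one.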

Examples

// Perform the GEMM C = A*B
(c = matmul(a, b)).run();

Permuted A

// Perform the GEMM C = A^T * B
auto at = a.PermuteMatrix();
(c = matmul(at, b)).run();

Permuted B

// Perform the GEMM C = A * B^T
auto bt = b.PermuteMatrix();
(c = matmul(a, bt)).run();
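A permuted matrix can be represented without copying: the transpose uses the same buffer with its dimensions and strides swapped. The plain-C++ sketch below illustrates that idea for both A^T * B and A * B^T. It uses a hypothetical View type and gemm_view function, neither of which is part of MatX. It is an analogy for how a permuted view can feed a GEMM, not a description of MatX internals.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical strided 2D view: element (r, c) lives at data[r * rs + c * cs].
struct View {
  const float* data;
  std::size_t rows, cols, rs, cs;
  float operator()(std::size_t r, std::size_t c) const {
    return data[r * rs + c * cs];
  }
  // Transpose by swapping dimensions and strides; no data is copied.
  View t() const { return View{data, cols, rows, cs, rs}; }
};

// C = A * B for strided views; C is returned as a dense row-major buffer.
std::vector<float> gemm_view(const View& A, const View& B) {
  assert(A.cols == B.rows);
  std::vector<float> C(A.rows * B.cols, 0.0f);
  for (std::size_t i = 0; i < A.rows; ++i)
    for (std::size_t j = 0; j < B.cols; ++j)
      for (std::size_t p = 0; p < A.cols; ++p)
        C[i * B.cols + j] += A(i, p) * B(p, j);
  return C;
}
```

For example, with a 2x3 row-major view A, gemm_view(A.t(), A) computes A^T * A and gemm_view(A, A.t()) computes A * A^T, both from the same underlying buffer.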

Batched

constexpr index_t batches = 5;
constexpr index_t m = 128;
constexpr index_t k = 256;
constexpr index_t n = 512;

tensor_t<TestType, 3> a{{batches, m, k}};
tensor_t<TestType, 3> b{{batches, k, n}};
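tensor_t<TestType, 3> c{{batches, m, n}};

// Generate input data with the unit-test fixture's test-vector generator and
// load it into a and b (any other initialization of a and b works as well)
this->pb->template InitAndRunTVGenerator<TestType>(
    "00_transforms", "matmul_operators", "run", {batches, m, k, n});

this->pb->NumpyToTensorView(a, "a");
this->pb->NumpyToTensorView(b, "b");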

// Perform a batched gemm with "batches" GEMMs
(c = matmul(a, b)).run();
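Semantically, the batched call performs one independent GEMM per index of the leading dimension. Below is a minimal sequential sketch in plain C++, assuming contiguous row-major {batches, m, k}, {batches, k, n}, and {batches, m, n} buffers. batched_gemm is a hypothetical name. A GPU implementation would run the batches in parallel rather than in a loop.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical batched GEMM sketch (not MatX API): batch b of C is A[b] * B[b].
// Every batch shares the same m, k, n, so matrix b starts at a fixed offset.
std::vector<float> batched_gemm(const std::vector<float>& A,
                                const std::vector<float>& B,
                                std::size_t batches, std::size_t m,
                                std::size_t k, std::size_t n) {
  std::vector<float> C(batches * m * n, 0.0f);
  for (std::size_t b = 0; b < batches; ++b) {
    const float* a = A.data() + b * m * k;   // start of matrix A[b]
    const float* bm = B.data() + b * k * n;  // start of matrix B[b]
    float* c = C.data() + b * m * n;         // start of matrix C[b]
    for (std::size_t i = 0; i < m; ++i)
      for (std::size_t j = 0; j < n; ++j)
        for (std::size_t p = 0; p < k; ++p)
          c[i * n + j] += a[i * k + p] * bm[p * n + j];
  }
  return C;
}
```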

Strided Batched

constexpr index_t batches = 16;
constexpr index_t m = 128;
constexpr index_t k = 256;
constexpr index_t n = 512;

tensor_t<TestType, 3> a{{batches, m, k}};
tensor_t<TestType, 3> b{{batches, k, n}};
tensor_t<TestType, 3> c{{batches, m, n}};

// Views of every other batch of a and b (a step of 2 in the first dimension)
auto as = a.Slice({0, 0, 0}, {matxEnd, matxEnd, matxEnd}, {2, 1, 1});
auto bs = b.Slice({0, 0, 0}, {matxEnd, matxEnd, matxEnd}, {2, 1, 1});
tensor_t<TestType, 3> cs{{batches/2, m, n}};

// Generate input data with the unit-test fixture's test-vector generator
this->pb->template InitAndRunTVGenerator<TestType>(
    "00_transforms", "matmul_operators", "run", {batches, m, k, n});

this->pb->NumpyToTensorView(a, "a");
this->pb->NumpyToTensorView(b, "b");

// Perform a strided and batched GEMM where "as" and "bs" have a stride of 2 in their outermost (batch) dimension
(cs = matmul(as, bs)).run();
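The slices above use a step of 2 in the first (batch) dimension, so as and bs see batches 0, 2, 4, and so on. That is why cs holds batches/2 results. The hypothetical strided_batched_gemm function below is a plain-C++ sketch of that selection over contiguous buffers, not MatX code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical strided-batch sketch: A is {batches, m, k}, B is {batches, k, n},
// both contiguous row-major. Only every `stride`-th batch (0, stride, 2*stride,
// ...) is multiplied, mirroring a slice with a step on the batch dimension.
// The result holds ceil(batches / stride) matrices of shape m x n.
std::vector<float> strided_batched_gemm(const std::vector<float>& A,
                                        const std::vector<float>& B,
                                        std::size_t batches, std::size_t stride,
                                        std::size_t m, std::size_t k,
                                        std::size_t n) {
  const std::size_t out_batches = (batches + stride - 1) / stride;
  std::vector<float> C(out_batches * m * n, 0.0f);
  for (std::size_t ob = 0; ob < out_batches; ++ob) {
    const std::size_t src = ob * stride;  // batch index in A and B
    const float* a = A.data() + src * m * k;
    const float* b = B.data() + src * k * n;
    float* c = C.data() + ob * m * n;
    for (std::size_t i = 0; i < m; ++i)
      for (std::size_t j = 0; j < n; ++j)
        for (std::size_t p = 0; p < k; ++p)
          c[i * n + j] += a[i * k + p] * b[p * n + j];
  }
  return C;
}
```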
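
Permuted Axes

// Treat dimension 2 as the matrix rows and dimension 1 as the matrix columns
const int32_t axis[2] = {2, 1};
std::array<int, 3> perm({0, 2, 1});

// batches, m, k, and n are as in the previous examples; a3 ({batches, m, k})
// and b3 ({batches, k, n}) are existing input tensors
auto ai = make_tensor<TestType>({batches, k, m});
auto bi = make_tensor<TestType>({batches, n, k});
auto ci = make_tensor<TestType>({batches, n, m});

// Views with the last two dimensions swapped
auto ap = permute(ai, perm);
auto bp = permute(bi, perm);
auto cp = permute(ci, perm);

// Copy data into the permuted inputs
(ap = a3).run();
(bp = b3).run();

// Perform a GEMM with the last two dimensions permuted; afterwards cp
// ({batches, m, n}) holds the product of a3 and b3
(ci = matmul(ai, bi, axis)).run();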
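The axis argument selects which two dimensions hold the matrix. Judging from the example above, axis = {2, 1} reads dimension 2 as the matrix rows and dimension 1 as the matrix columns, with dimension 0 left as the batch. The sketch below assumes that interpretation. Tensor3, offset, and matmul_axis are hypothetical plain-C++ helpers, not MatX functions.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical rank-3 contiguous row-major tensor; dimension 0 is the batch.
struct Tensor3 {
  std::array<std::size_t, 3> shape;
  std::vector<float> data;
};

// Flat offset of matrix element (r, c) in batch b, assuming axis[0] names the
// dimension holding matrix rows and axis[1] the dimension holding columns.
std::size_t offset(const Tensor3& t, const std::array<int, 2>& axis,
                   std::size_t b, std::size_t r, std::size_t c) {
  std::array<std::size_t, 3> idx{};
  idx[0] = b;
  idx[static_cast<std::size_t>(axis[0])] = r;
  idx[static_cast<std::size_t>(axis[1])] = c;
  return (idx[0] * t.shape[1] + idx[1]) * t.shape[2] + idx[2];
}

// Batched C = A * B where each operand's matrix dimensions are picked by axis
// (axis must be {1, 2} or {2, 1} in this rank-3 sketch).
void matmul_axis(const Tensor3& A, const Tensor3& B, Tensor3& C,
                 const std::array<int, 2>& axis) {
  assert(axis[0] + axis[1] == 3 && axis[0] != axis[1]);
  const std::size_t batches = A.shape[0];
  const std::size_t m = A.shape[static_cast<std::size_t>(axis[0])];
  const std::size_t k = A.shape[static_cast<std::size_t>(axis[1])];
  const std::size_t n = B.shape[static_cast<std::size_t>(axis[1])];
  assert(B.shape[static_cast<std::size_t>(axis[0])] == k);
  for (std::size_t b = 0; b < batches; ++b)
    for (std::size_t i = 0; i < m; ++i)
      for (std::size_t j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (std::size_t p = 0; p < k; ++p)
          acc += A.data[offset(A, axis, b, i, p)] * B.data[offset(B, axis, b, p, j)];
        C.data[offset(C, axis, b, i, j)] = acc;
      }
}
```

With axis = {2, 1}, an input stored as {batches, k, m} is read as batches of m x k matrices, matching the shapes of ai, bi, and ci in the example.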