Script.tcgen05.mma

Script.tcgen05.mma(a, b, d, enable_input_d, cta_group=1)

Perform tensor core matrix multiply-accumulate with TMEM accumulator.

Computes d = a @ b + d (or d = a @ b when enable_input_d=False). All operands must be 2D: a is [M, K], b is [K, N], d is [M, N].
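The arithmetic above can be modelled on the host in a few lines of plain Python. This is only a sketch of the documented semantics, not the tensor-core instruction itself: the name mma_reference and the use of nested lists in place of SharedTensor and TMemoryTensor are assumptions made for illustration.

```python
def mma_reference(a, b, d, enable_input_d):
    """Host-side model of d = a @ b (+ d). Hypothetical helper, not the real API."""
    m, k = len(a), len(a[0])
    k_b, n = len(b), len(b[0])
    # Shape contract from the docs: a is [M, K], b is [K, N], d is [M, N].
    if k != k_b:
        raise ValueError("inner dimensions of a and b differ")
    if len(d) != m or len(d[0]) != n:
        raise ValueError("d must have shape [M, N]")
    for i in range(m):
        for j in range(n):
            # enable_input_d=False discards the previous accumulator value.
            acc = d[i][j] if enable_input_d else 0
            for p in range(k):
                acc += a[i][p] * b[p][j]
            d[i][j] = acc  # d is updated in place; nothing is returned


a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
d = [[1, 1], [1, 1]]
mma_reference(a, b, d, enable_input_d=True)   # accumulates onto the old d
accumulated = [row[:] for row in d]
mma_reference(a, b, d, enable_input_d=False)  # overwrites d with a @ b
overwritten = [row[:] for row in d]
```

A common pattern with this kind of flag is to pass False on the first step of a K-loop so the accumulator does not need a separate zeroing pass, then True on later steps; whether that fits a given kernel is up to the caller.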

When cta_group=2, two CTAs collaborate on a single MMA. Each CTA provides half of each operand and holds half of the accumulator:

  • A = [a0; a1] — A has shape (M, K), each CTA provides (M/2, K)

  • B = [b0, b1] — B has shape (K, N), each CTA provides (K, N/2)

  • D = [d0; d1] — D has shape (M, N), each CTA holds (M/2, N)

CTA0 is the CTA whose cluster rank has its least significant bit equal to 0 (an even rank); CTA1 is the other CTA of the pair.
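The operand split for cta_group=2 can be illustrated with a small host-side model. The sketch below assumes the partitioning exactly as listed above (A and D split by rows, B split by columns) and uses nested lists; the names matmul_acc, cta_index and two_cta_mma are hypothetical and do not exist in the library.

```python
def matmul_acc(a, b, d, enable_input_d):
    """Return a @ b (+ d) as a new nested list. Hypothetical helper."""
    rows, inner, cols = len(a), len(b), len(b[0])
    out = []
    for i in range(rows):
        row = []
        for j in range(cols):
            acc = d[i][j] if enable_input_d else 0
            for p in range(inner):
                acc += a[i][p] * b[p][j]
            row.append(acc)
        out.append(row)
    return out


def cta_index(cluster_rank):
    """0 for CTA0, 1 for CTA1: the last bit of the cluster rank."""
    return cluster_rank & 1


def two_cta_mma(a_halves, b_halves, d_halves, enable_input_d):
    """Model of the cta_group=2 split; entry i of each pair belongs to CTA i."""
    b0, b1 = b_halves
    # B = [b0, b1]: the two (K, N/2) column blocks sit side by side.
    b_full = [left + right for left, right in zip(b0, b1)]
    # D = [d0; d1]: CTA i ends up holding the (M/2, N) block a_i @ B (+ d_i).
    return [
        matmul_acc(a_i, b_full, d_i, enable_input_d)
        for a_i, d_i in zip(a_halves, d_halves)
    ]


# Full problem: A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]].
a_halves = ([[1, 2]], [[3, 4]])      # a0, a1: one row of A each
b_halves = ([[5], [7]], [[6], [8]])  # b0, b1: one column of B each
d_halves = ([[1, 1]], [[2, 2]])      # d0, d1: one row of D each
d0, d1 = two_cta_mma(a_halves, b_halves, d_halves, enable_input_d=True)
```

Stacking d0 on top of d1 gives the same result as a single-CTA MMA on the full A, B and D, which is the property the split is meant to preserve.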

Parameters:
  • a (SharedTensor | TMemoryTensor) – Left-hand operand [M, K]. Can be in shared memory or tensor memory.

  • b (SharedTensor) – Right-hand operand [K, N]. Must be in shared memory.

  • d (TMemoryTensor) – Accumulator [M, N] in tensor memory. Used as both input and output.

  • enable_input_d (Expr | bool) – If True, computes d = a @ b + d. If False, computes d = a @ b.

  • cta_group (int) – CTA group size. 1 for single-CTA, 2 for two-CTA collaborative MMA.

Return type:

None

Notes

  • Thread group: Must be executed by a single warp (use self.single_warp()).

  • Hardware: Requires compute capability 10.0a+ (sm_100a).

  • PTX: tcgen05.mma