Script.tcgen05.mma
- Script.tcgen05.mma(a, b, d, enable_input_d, cta_group=1)
Perform a tensor core matrix multiply-accumulate (MMA) with the accumulator in tensor memory (TMEM).
Computes d = a @ b + d (or d = a @ b when enable_input_d=False). All operands must be 2D: a is [M, K], b is [K, N], and d is [M, N].
When cta_group=2, two CTAs collaborate on the MMA. Each CTA provides half of the operands and holds half of the accumulator:
  A = [a0; a1]: A has shape (M, K); each CTA provides (M/2, K).
  B = [b0, b1]: B has shape (K, N); each CTA provides (K, N/2).
  D = [d0; d1]: D has shape (M, N); each CTA holds (M/2, N).
CTA0 is the CTA whose cluster rank has its last (least significant) bit equal to 0; CTA1 is the other.
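The cta_group=2 partitioning can be modeled on the host with NumPy. This is a sketch of the math only. It does not call the Script API, and the helper names `cta_role` and `two_cta_mma_reference` are hypothetical. The model shows that CTA r's accumulator rows d_r depend on its own a_r and on both halves of B. This is why the two CTAs collaborate instead of running independent MMAs.

```python
import numpy as np


def cta_role(cluster_rank: int) -> int:
    """Hypothetical helper: 0 for CTA0, 1 for CTA1 (last bit of the cluster rank)."""
    return cluster_rank & 1


def two_cta_mma_reference(a, b, d, enable_input_d):
    """Hypothetical host model of cta_group=2.

    Returns [d0_new, d1_new], the (M/2, N) accumulator slice each CTA holds.
    """
    M, K = a.shape
    Kb, N = b.shape
    assert K == Kb and d.shape == (M, N), "expected a [M, K], b [K, N], d [M, N]"
    assert M % 2 == 0 and N % 2 == 0, "M and N must split evenly across 2 CTAs"

    a_parts = [a[: M // 2], a[M // 2 :]]        # A = [a0; a1], split along M
    b_parts = [b[:, : N // 2], b[:, N // 2 :]]  # B = [b0, b1], split along N
    d_parts = [d[: M // 2], d[M // 2 :]]        # D = [d0; d1], split along M

    # Each CTA's rows of D need the full B, i.e. both CTAs' halves.
    b_full = np.concatenate(b_parts, axis=1)

    out = []
    for role in (0, 1):
        acc = a_parts[role] @ b_full
        if enable_input_d:
            acc = acc + d_parts[role]
        out.append(acc)
    return out
```

Stacking the two returned slices with `np.vstack` reproduces the single-CTA result `a @ b + d` (or `a @ b` when `enable_input_d` is false).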
- Parameters:
a (SharedTensor | TMemoryTensor) – Left-hand operand [M, K]. Can be in shared memory or tensor memory.
b (SharedTensor) – Right-hand operand [K, N]. Must be in shared memory.
d (TMemoryTensor) – Accumulator [M, N] in tensor memory. Used as both input and output.
enable_input_d (Expr | bool) – If True, computes d = a @ b + d. If False, computes d = a @ b.
cta_group (int) – CTA group size: 1 for a single-CTA MMA, 2 for a two-CTA collaborative MMA.
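The enable_input_d semantics can be checked against a small NumPy model. This is a host-side sketch, not the Script API. `mma_reference` and `k_loop_reference` are hypothetical names. The K-tiling loop shows one common way to use the flag, and that pattern is an assumption, not something this reference specifies. The first K tile passes enable_input_d=False so whatever the accumulator held before is ignored. Later tiles pass True to accumulate.

```python
import numpy as np


def mma_reference(a, b, d, enable_input_d):
    """Hypothetical model: d = a @ b + d if enable_input_d, else d = a @ b (in place)."""
    M, K = a.shape
    Kb, N = b.shape
    if Kb != K or d.shape != (M, N):
        raise ValueError("expected a [M, K], b [K, N], d [M, N]")
    # d is both input and output, so it is updated in place.
    if enable_input_d:
        d += a @ b
    else:
        d[...] = a @ b
    return d


def k_loop_reference(a, b, block_k):
    """Hypothetical K-tiled loop: overwrite on the first tile, accumulate afterwards."""
    M, K = a.shape
    N = b.shape[1]
    d = np.full((M, N), np.nan)  # stand-in for uninitialized accumulator contents
    for i, k0 in enumerate(range(0, K, block_k)):
        mma_reference(a[:, k0:k0 + block_k], b[k0:k0 + block_k], d,
                      enable_input_d=(i > 0))
    return d
```

Because the first call overwrites `d`, the NaN placeholder never reaches the result. If the first tile used `enable_input_d=True`, the stale contents would leak into every output element.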
- Return type:
None
Notes
Thread group: Must be executed by a single warp (use self.single_warp()).
Hardware: Requires compute capability 10.0a+ (sm_100a).
PTX: tcgen05.mma