Compute a matrix multiplication in which both matrix A and matrix B are stored in MXFP8 format. The equation below defines the reference semantics for correctness; optimized kernels should decode payloads and apply scales on the fly rather than materializing full FP32 $A_{\text{dequant}}$ or $B_{\text{dequant}}$.
$$c_{ij} = \sum_{\ell=0}^{K-1} A_{\text{dequant},\,i\ell}\, B_{\text{dequant},\,j\ell}.$$
Note: B is stored row-major as N×K (i.e. $B_{\text{dequant}}$ is N×K), so the multiplication is effectively $C = A_{\text{dequant}} B_{\text{dequant}}^T$.
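For concreteness, dequantization follows the OCP MX convention: each group of 32 consecutive elements along K shares one E8M0 scale, decoded as a power of two. A sketch of the rule in the logical (un-swizzled) view, where $s_A$ denotes A's scale array and $\mathrm{fp8}(\cdot)$ the FP8 payload decode (both symbols are our notation, not the task's):

$$A_{\text{dequant},\,i\ell} = 2^{\,s_A[i,\,\lfloor \ell/32 \rfloor] - 127}\cdot \mathrm{fp8}(qa[i,\ell]),$$

and likewise for $B_{\text{dequant}}$ from qb and scaleb.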
Input
- qa: MXFP8 payload bytes for matrix A of shape M×K (row-major)
- scalea: per-block E8M0 scale bytes for A, logical shape M×K/32
- qb: MXFP8 payload bytes for matrix B of shape N×K (row-major; transposed before multiply)
- scaleb: per-block E8M0 scale bytes for B, logical shape N×K/32
- M, N, K: matrix dimensions (K and N divisible by 32)
Output
- c: FP32 matrix of shape M×N where $C = A_{\text{dequant}} B_{\text{dequant}}^T$
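As a sanity reference, here is a minimal dequantize-then-matmul sketch of the semantics above. It is an illustration under stated assumptions, not the graded implementation: it assumes E4M3 payloads and, for readability, plain row-major M×K/32 (and N×K/32) scale bytes rather than the swizzled layout described in the notes below; `mxfp8_dequant_matmul` is our name, not part of the task.

```python
import torch

def mxfp8_dequant_matmul(qa, scalea, qb, scaleb, M, N, K):
    """Slow but unambiguous FP32 reference for the equation above.

    Assumptions (not fixed by the spec): E4M3 payload bytes and
    un-swizzled, row-major scale bytes so the block structure is visible.
    """
    BLOCK = 32
    # Reinterpret payload bytes as FP8 and widen to FP32.
    a = qa.view(torch.float8_e4m3fn).reshape(M, K).to(torch.float32)
    b = qb.view(torch.float8_e4m3fn).reshape(N, K).to(torch.float32)

    def e8m0_to_f32(s):
        # E8M0 stores only a biased exponent: value = 2^(byte - 127).
        out = torch.exp2(s.to(torch.float32) - 127.0)
        out[s == 0xFF] = float("nan")  # 0xFF is the E8M0 NaN encoding
        return out

    sa = e8m0_to_f32(scalea.reshape(M, K // BLOCK))
    sb = e8m0_to_f32(scaleb.reshape(N, K // BLOCK))
    # Each scale applies to 32 consecutive elements along K.
    a_dq = a * sa.repeat_interleave(BLOCK, dim=1)
    b_dq = b * sb.repeat_interleave(BLOCK, dim=1)
    # c_ij = sum_l A_dq[i, l] * B_dq[j, l]  ==  A_dq @ B_dq^T
    return a_dq @ b_dq.t()
```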
Notes
- Check out the MXFP8 format for more background.
- We use torch.scaled_mm as the reference implementation for correctness. scaled_mm expects the second matrix in column-major layout, so the reference passes $B_{\text{dequant}}^T$ (shape K×N) as the second argument; the result is still $C = A_{\text{dequant}} B_{\text{dequant}}^T$, logically unchanged.
- Scale tensors passed as scale_a / scale_b are assumed to already be laid out in the swizzled blockwise format that scaled_mm uses for MXFP8.
- Treat these pointers as scale storage already in the swizzled 32x4x4 layout; do not apply an additional swizzle.
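For orientation, a hedged sketch of how the scaled_mm reference path might be invoked. It assumes a recent PyTorch build that exposes the private torch._scaled_mm op with E8M0 block-scale support (Blackwell-class GPUs), the torch.float8_e8m0fnu dtype, and E4M3 payloads; exact argument conventions may differ in your environment.

```python
import torch

def scaled_mm_reference(qa, scalea, qb, scaleb, M, N, K):
    # Reinterpret payload bytes as FP8 (assumption: E4M3 payloads).
    a = qa.view(torch.float8_e4m3fn).reshape(M, K)
    b = qb.view(torch.float8_e4m3fn).reshape(N, K)
    # Scale bytes are already in the swizzled 32x4x4 blockwise layout,
    # so they are only reinterpreted as E8M0, never re-swizzled.
    sa = scalea.view(torch.float8_e8m0fnu)
    sb = scaleb.view(torch.float8_e8m0fnu)
    # b.t() is a K x N view with column-major strides, which is the
    # second-operand layout scaled_mm requires; no copy is made.
    # The product is A_dequant @ B_dequant^T, matching the spec.
    return torch._scaled_mm(a, b.t(), scale_a=sa, scale_b=sb,
                            out_dtype=torch.float32)
```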
Test Case Sizes
- 1024 x 1024 x 1024
- 2048 x 1024 x 2048
- 4096 x 2048 x 4096
- 4096 x 4096 x 4096
- 8192 x 4096 x 8192
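A possible sanity harness over these sizes, reusing the mxfp8_dequant_matmul sketch from above. Everything here is assumption-level scaffolding: my_mxfp8_gemm is a placeholder for the kernel under test, the sizes are read as M × N × K, and the random inputs avoid NaN payload encodings and use un-swizzled scales to match the sketch.

```python
import torch

SIZES = [(1024, 1024, 1024), (2048, 1024, 2048), (4096, 2048, 4096),
         (4096, 4096, 4096), (8192, 4096, 8192)]  # read as M, N, K

def random_inputs(M, N, K):
    qa = torch.randint(0, 256, (M * K,), dtype=torch.uint8)
    qb = torch.randint(0, 256, (N * K,), dtype=torch.uint8)
    for q in (qa, qb):
        q[(q & 0x7F) == 0x7F] = 0  # 0x7F / 0xFF are E4M3 NaN payloads
    # Exponents near 127 keep the 2^(s - 127) scales close to 1.0.
    scalea = torch.randint(120, 135, (M * (K // 32),), dtype=torch.uint8)
    scaleb = torch.randint(120, 135, (N * (K // 32),), dtype=torch.uint8)
    return qa, scalea, qb, scaleb

for M, N, K in SIZES:
    qa, scalea, qb, scaleb = random_inputs(M, N, K)
    ref = mxfp8_dequant_matmul(qa, scalea, qb, scaleb, M, N, K)
    # out = my_mxfp8_gemm(qa, scalea, qb, scaleb, M, N, K)  # kernel under test
    # torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)
    print(f"{M}x{N}x{K}: reference shape {tuple(ref.shape)}")
```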