MXFP8 GEMM

HARD

Compute a matrix multiplication where both matrix $A$ and matrix $B$ are stored in MXFP8 format. The equation below defines the reference semantics for correctness; optimized kernels should decode payloads and apply scales on the fly, and should avoid materializing the full FP32 matrices $A_{\mathrm{dequant}}$ or $B_{\mathrm{dequant}}$.

$$c_{ij} = \sum_{\ell=0}^{K-1} A_{\mathrm{dequant},i\ell} \, B_{\mathrm{dequant},j\ell}.$$

Note: $B$ is stored in row-major order as $N \times K$ (i.e. $B_{\mathrm{dequant}}$ is $N \times K$), so the multiplication is effectively $C = A_{\mathrm{dequant}} \, B_{\mathrm{dequant}}^T$.
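
As a hedged illustration of what applying scales on the fly can mean: the scale shapes in this problem are $M \times K/32$ and $N \times K/32$, so each scale is shared by 32 consecutive elements along $K$, and the reference sum can be regrouped per block. The symbols $\hat{a}$, $\hat{b}$ (decoded payload values) and $s^A$, $s^B$ (decoded scales) are notation introduced here for illustration and do not appear in the original statement.

$$A_{\mathrm{dequant},i\ell} = s^A_{i,\lfloor \ell/32 \rfloor} \, \hat{a}_{i\ell}, \qquad B_{\mathrm{dequant},j\ell} = s^B_{j,\lfloor \ell/32 \rfloor} \, \hat{b}_{j\ell}$$

$$c_{ij} = \sum_{b=0}^{K/32-1} s^A_{i,b} \, s^B_{j,b} \sum_{t=0}^{31} \hat{a}_{i,32b+t} \, \hat{b}_{j,32b+t}$$

Under this regrouping, a kernel can accumulate the inner 32-element dot product on decoded payloads and multiply by the two scales once per block, rather than scaling every element.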

Input

  • q_a: MXFP8 payload bytes for matrix $A$ of shape $M \times K$ (row-major)
  • scale_a: per-block E8M0 scale bytes for $A$, logical shape $M \times K/32$
  • q_b: MXFP8 payload bytes for matrix $B$ of shape $N \times K$ (row-major; transposed before multiply)
  • scale_b: per-block E8M0 scale bytes for $B$, logical shape $N \times K/32$
  • $M$, $N$, $K$: matrix dimensions ($K$ and $N$ divisible by 32)
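
The byte formats above can be decoded as in the following minimal host-side sketch. It rests on assumptions the statement does not confirm: the payload element type is taken to be E4M3 (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities), and `dequant_at` indexes the scales by their logical row-major $rows \times K/32$ shape rather than the swizzled storage actually passed to the kernel. The function names `decode_e8m0`, `decode_e4m3`, and `dequant_at` are hypothetical helpers, not part of any provided API.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Decode one E8M0 scale byte: an unsigned 8-bit exponent with bias 127.
// The value 0xFF encodes NaN.
inline float decode_e8m0(std::uint8_t bits) {
    if (bits == 0xFF) return std::numeric_limits<float>::quiet_NaN();
    return std::ldexp(1.0f, static_cast<int>(bits) - 127);
}

// Decode one FP8 payload byte, ASSUMING the E4M3 element type.
inline float decode_e4m3(std::uint8_t bits) {
    const float sign = (bits & 0x80) ? -1.0f : 1.0f;
    const int exponent = (bits >> 3) & 0x0F;
    const int mantissa = bits & 0x07;
    if (exponent == 0x0F && mantissa == 0x07)
        return std::numeric_limits<float>::quiet_NaN();
    if (exponent == 0)  // subnormal: (mantissa / 8) * 2^-6 = mantissa * 2^-9
        return sign * std::ldexp(static_cast<float>(mantissa), -9);
    // normal: (1 + mantissa / 8) * 2^(exponent - 7) = (8 + mantissa) * 2^(exponent - 10)
    return sign * std::ldexp(static_cast<float>(8 + mantissa), exponent - 10);
}

// Dequantize element (row, col) of a row-major rows x K payload.
// ASSUMPTION: scale_logical is a plain row-major rows x (K / 32) array,
// i.e. the logical shape, not the swizzled storage layout.
inline float dequant_at(const std::uint8_t* q, const std::uint8_t* scale_logical,
                        int K, int row, int col) {
    assert(K % 32 == 0);
    const float s = decode_e8m0(scale_logical[row * (K / 32) + col / 32]);
    return s * decode_e4m3(q[row * K + col]);
}
```

A helper like this is only useful for checking small cases on the host; an optimized kernel would decode inside the inner loop or rely on hardware support instead.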

Output

  • $c$: FP32 matrix of shape $M \times N$ where $c = A_{\mathrm{dequant}} \, B_{\mathrm{dequant}}^T$

Notes

  • Check out the MXFP8 format for more background.
  • We use torch.scaled_mm as the reference implementation for correctness. scaled_mm expects the second matrix in column-major layout; the reference therefore passes $B_{\mathrm{dequant}}^T$ (shape $K \times N$) as the second argument, so the result remains $c = A_{\mathrm{dequant}} \, B_{\mathrm{dequant}}^T$ (logically unchanged).
  • Scale tensors passed as scale_a / scale_b are assumed to already be laid out in the same swizzled blockwise format that scaled_mm uses for MXFP8.
  • Treat these pointers as scale storage that is already in the swizzled 32x4x4 layout; do not apply an additional swizzle.
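
To make the swizzled layout concrete, here is a sketch of how a logical scale coordinate might map to a byte offset. It assumes the 32x4x4 layout is the commonly used block-scaling arrangement in which scales are grouped into tiles of 128 rows by 4 scale columns (512 bytes each), tiles are stored in row-major order, and each tile is internally ordered as 32 x 4 x 4. The statement only names the layout, so this mapping and the helper name `swizzled_scale_offset` are assumptions to verify against the reference before relying on them.

```cpp
#include <cassert>
#include <cstddef>

// Byte offset of the scale for logical coordinate (row, k_block) in swizzled
// scale storage. ASSUMED layout: 128-row x 4-column tiles of 512 bytes each,
// tiles stored in row-major order, each tile arranged internally as 32 x 4 x 4.
// num_k_blocks is K / 32; the tile grid is padded up to a multiple of 4 columns.
inline std::size_t swizzled_scale_offset(std::size_t row, std::size_t k_block,
                                         std::size_t num_k_blocks) {
    assert(k_block < num_k_blocks);
    const std::size_t tiles_per_row = (num_k_blocks + 3) / 4;
    const std::size_t tile = (row / 128) * tiles_per_row + (k_block / 4);
    const std::size_t r = row % 128;    // row within the tile
    const std::size_t c = k_block % 4;  // scale column within the tile
    return tile * 512 + (r % 32) * 16 + (r / 32) * 4 + c;
}
```

Under this assumed mapping, the four scales covering 128 consecutive values of $K$ for a given row sit in adjacent bytes, so a kernel could fetch them with a single 32-bit load.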

Test Case Sizes

  • 1024 x 1024 x 1024
  • 2048 x 1024 x 2048
  • 4096 x 2048 x 4096
  • 4096 x 4096 x 4096
  • 8192 x 4096 x 8192