Matrix Vector Multiplication
Triton
B200
Error
TEST CASES
0/0 passed
Error Details (Unknown error)
Traceback (most recent call last):
File "/usr/local/lib/python3.11/site-packages/triton/language/core.py", line 34, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/site-packages/triton/language/core.py", line 1814, in dot
return semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype, _builder)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/site-packages/triton/language/semantic.py", line 1571, in dot
and rhs.shape[-1].value >= min_dot_size[1], \
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Input shapes should have M >= 16, N >= 16 and K >= 16
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/root/runner.py", line 124, in run_checker
solution_func(*(list(input_tensors) + [actual_output] + list(extra_params)))
File "/tmp/tmpyi4yvmri/triton_solution.py", line 56, in solution
matmul_solution(input_a, input_b, output_c, m, 1, k)
File "/tmp/tmpyi4yvmri/triton_solution.py", line 46, in matmul_solution
_matmul_kernel[grid](
File "/usr/local/lib/python3.11/site-packages/triton/runtime/jit.py", line 347, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/site-packages/triton/runtime/jit.py", line 569, in run
kernel = self.compile(src, target=target, options=options.__dict__)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/site-packages/triton/compiler/compiler.py", line 278, in compile
module = src.make_ir(options, codegen_fns, module_map, context)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.11/site-packages/triton/compiler/compiler.py", line 81, in make_ir
return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
triton.compiler.errors.CompilationError: at 26:15:
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
offs_k = k + tl.arange(0, BLOCK_K)
a_ptrs = A + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
acc += tl.dot(a, b, allow_tf32=False)
^
Submitted Code