Tensara Logo

tensara

Back

Solution: 2D Convolution via FFT (2nd place as of 8th May 2026)

nhauber99

·

May 8, 2026


Solution: 2D Convolution

I don't think this is how the problem was intended to be solved, since using cuFFT is a boring and well-known approach; but given that the leaderboard already includes FFT-based approaches, here is one.

Problem: 2D Convolution

Submitted on: 08/05/2026, 17:12:43

  • Runtime: 615.80 μs
  • Performance: 37239.81 GFLOPS
  • Tests: 6/6
#include <cuda_runtime.h>
#include <cufft.h>
#include <dlfcn.h>
#include <stddef.h>

// Edge length (in output elements) of the square tile produced per FFT patch.
// Each patch transforms a (PATCH_TILE + Kh - 1) x (PATCH_TILE + Kw - 1) haloed
// input region, zero-padded up to the next 2/3/5/7-smooth FFT size.
#ifndef PATCH_TILE
#define PATCH_TILE 512
#endif

// Maximum number of patches transformed together by one batched cuFFT plan
// (the actual batch is min(PATCH_BATCH, number of tiles)).
#ifndef PATCH_BATCH
#define PATCH_BATCH 64
#endif

// Number of stream slots (each with its own stream, work buffer and plans)
// across which patch batches are distributed round-robin.
#ifndef PATCH_STREAMS
#define PATCH_STREAMS 2
#endif

// When 1, the filter spectrum is reused across calls while the problem
// dimensions and the B pointer value are unchanged. Only the pointer is
// compared, not the data, so set to 0 if B may be mutated in place or if a
// different filter may arrive at a previously seen address (e.g. when a
// caching allocator hands the same address to a new tensor).
#ifndef CACHE_KERNEL_FFT
#define CACHE_KERNEL_FFT 1
#endif

// Table of cuFFT entry points resolved at run time with dlsym (see
// get_cufft); presumably used so the binary need not link libcufft directly.
struct CufftApi {
    // dlopen() handle of the loaded cuFFT library, or nullptr.
    void* lib{nullptr};

    // Function-pointer types matching the cuFFT functions this file calls.
    typedef cufftResult (*cufftPlanMany_t)(
        cufftHandle* plan, int rank, int* n,
        int* inembed, int istride, int idist,
        int* onembed, int ostride, int odist,
        cufftType type, int batch);
    typedef cufftResult (*cufftExecR2C_t)(cufftHandle, cufftReal*, cufftComplex*);
    typedef cufftResult (*cufftExecC2R_t)(cufftHandle, cufftComplex*, cufftReal*);
    typedef cufftResult (*cufftDestroy_t)(cufftHandle);
    typedef cufftResult (*cufftSetStream_t)(cufftHandle, cudaStream_t);

    // Resolved entry points; nullptr until loaded or if a symbol is missing.
    cufftPlanMany_t planMany{nullptr};
    cufftExecR2C_t execR2C{nullptr};
    cufftExecC2R_t execC2R{nullptr};
    cufftDestroy_t destroy{nullptr};
    cufftSetStream_t setStream{nullptr};

    // True only when every entry point above was resolved.
    bool ok{false};
};

// Looks up `name` in the library handle `lib` via dlsym and reinterprets the
// returned address as the function-pointer type T (nullptr if not found).
template <typename T>
static T load_sym(void* lib, const char* name) {
    void* const symbol = dlsym(lib, name);
    return reinterpret_cast<T>(symbol);
}

// Returns the process-wide cuFFT entry-point table, loading it on the first
// call. dlopen is tried on "libcufft.so", "libcufft.so.12" and
// "libcufft.so.11" in that order; the first library that opens is kept and
// the five functions used by this file are resolved from it. `ok` is true
// only if a library opened and all five symbols resolved. Only one load
// attempt is ever made (a failed attempt is not retried), and the
// `initialized` flag is not guarded against concurrent first calls.
static CufftApi& get_cufft() {
    static CufftApi api;
    static bool initialized = false;

    if (initialized) {
        return api;
    }
    initialized = true;

    static const char* const kLibraryNames[] = {
        "libcufft.so",
        "libcufft.so.12",
        "libcufft.so.11"
    };
    const size_t kNumNames = sizeof(kLibraryNames) / sizeof(kLibraryNames[0]);

    // Stop at the first candidate that dlopen accepts.
    for (size_t i = 0; api.lib == nullptr && i < kNumNames; ++i) {
        api.lib = dlopen(kLibraryNames[i], RTLD_NOW | RTLD_LOCAL);
    }

    if (api.lib == nullptr) {
        return api;
    }

    void* const lib = api.lib;
    api.planMany  = load_sym<CufftApi::cufftPlanMany_t>(lib, "cufftPlanMany");
    api.execR2C   = load_sym<CufftApi::cufftExecR2C_t>(lib, "cufftExecR2C");
    api.execC2R   = load_sym<CufftApi::cufftExecC2R_t>(lib, "cufftExecC2R");
    api.destroy   = load_sym<CufftApi::cufftDestroy_t>(lib, "cufftDestroy");
    api.setStream = load_sym<CufftApi::cufftSetStream_t>(lib, "cufftSetStream");

    const bool all_resolved =
        api.planMany != nullptr &&
        api.execR2C != nullptr &&
        api.execC2R != nullptr &&
        api.destroy != nullptr &&
        api.setStream != nullptr;
    api.ok = all_resolved;

    return api;
}

// Returns 1 if n is a positive integer whose only prime factors are 2, 3, 5
// and 7 (the sizes cuFFT handles most efficiently), 0 otherwise.
// Non-positive n returns 0; in particular n == 0 must be rejected up front,
// because 0 % p == 0 for every p and the division loop would never end.
static int is_fast_fft_size(int n) {
    if (n <= 0) {
        return 0;
    }

    static const int kFactors[] = {2, 3, 5, 7};

    int m = n;
    for (int p : kFactors) {
        while (m % p == 0) {
            m /= p;
        }
    }

    return m == 1;
}

// Smallest value >= n accepted by is_fast_fft_size, found by scanning upward
// one integer at a time.
static int next_fast_fft_size(int n) {
    int candidate = n;
    for (; !is_fast_fft_size(candidate); ++candidate) {
    }
    return candidate;
}

// Ceiling division a / b computed as (a + b - 1) / b; intended for a >= 0,
// b > 0.
static int div_up_int(int a, int b) {
    const int rounded_up = a + (b - 1);
    return rounded_up / b;
}

// Smaller of two ints.
static int min_int(int a, int b) {
    if (b <= a) {
        return b;
    }
    return a;
}

// Number of 256-thread blocks needed to cover n elements, i.e. ceil(n / 256),
// narrowed to int for use as a 1-D grid dimension. Every launch in this file
// pairs it with a block size of 256.
static int grid_1d(size_t n) {
    const size_t kThreadsPerBlock = 256;
    const size_t blocks = (n + (kThreadsPerBlock - 1)) / kThreadsPerBlock;
    return static_cast<int>(blocks);
}

// Resources owned by one stream slot: a stream, a work buffer that holds a
// whole batch of patches, and the batched forward/inverse plans bound to
// that stream (all created in ensure_cache, freed in release_cache).
struct StreamSlot {
    cudaStream_t stream{nullptr};

    // batch_capacity patches of work_elems floats each, transformed in place.
    float* Awork{nullptr};

    cufftHandle r2c{0};
    cufftHandle c2r{0};

    // Plan validity is tracked with explicit flags rather than by the handle
    // value, so release_cache only destroys plans that were created.
    bool has_r2c{false};
    bool has_c2r{false};
};

// Every device resource and derived size for one (H, W, Kh, Kw) problem
// shape. The H x W output is split into PATCH_TILE x PATCH_TILE tiles, and
// each tile is produced from an in-place P x Q real FFT of its haloed input.
struct PatchFftCache {
    // Problem shape this cache was built for.
    int H{0};
    int W{0};
    int Kh{0};
    int Kw{0};

    // Tile grid covering the output.
    int tiles_y{0};
    int tiles_x{0};
    int num_patches{0};

    // FFT geometry: logical size P x Q, Qfreq = Q / 2 + 1 complex columns,
    // real row pitch = 2 * Qfreq floats (in-place R2C layout).
    int P{0};
    int Q{0};
    int Qfreq{0};
    int pitch{0};

    // Patches per batched transform, and how many slots were initialised.
    int batch_capacity{0};
    int active_streams{0};

    // Per-patch sizes: real floats (P * pitch), complex values (P * Qfreq),
    // and bytes of the real buffer.
    size_t work_elems{0};
    size_t freq_elems{0};
    size_t work_bytes{0};

    StreamSlot slots[PATCH_STREAMS];

    // Stream, buffer and single-image R2C plan used for the filter spectrum.
    cudaStream_t kernel_stream{nullptr};
    float* Bwork{nullptr};

    cufftHandle b_r2c{0};
    bool has_b_r2c{false};

    // Which B pointer's spectrum currently sits in Bwork, and whether valid.
    const float* cached_B{nullptr};
    bool kernel_fft_ready{false};

    // Set once ensure_cache has built every resource above.
    bool ready{false};
};

// Creates a batched 2-D real-to-complex plan of logical size P x Q.
// Real input rows are `pitch` floats apart and complex output rows are Qfreq
// elements apart; consecutive batch entries are P rows apart in each domain.
// Callers pass pitch = 2 * Qfreq so the transform can run in place.
// Returns true iff cufftPlanMany succeeded.
static bool make_r2c_plan(
    CufftApi& cufft,
    cufftHandle* plan,
    int P,
    int Q,
    int Qfreq,
    int pitch,
    int batch
) {
    int dims[2] = {P, Q};
    int real_layout[2] = {P, pitch};
    int freq_layout[2] = {P, Qfreq};

    const int unit_stride = 1;
    const int real_image_dist = P * pitch;
    const int freq_image_dist = P * Qfreq;

    const cufftResult status = cufft.planMany(
        plan, 2, dims,
        real_layout, unit_stride, real_image_dist,
        freq_layout, unit_stride, freq_image_dist,
        CUFFT_R2C, batch
    );
    return status == CUFFT_SUCCESS;
}

// Creates a batched 2-D complex-to-real plan of logical size P x Q, the
// inverse counterpart of make_r2c_plan. Complex input rows are Qfreq elements
// apart and real output rows are `pitch` floats apart; consecutive batch
// entries are P rows apart in each domain. Returns true iff cufftPlanMany
// succeeded.
static bool make_c2r_plan(
    CufftApi& cufft,
    cufftHandle* plan,
    int P,
    int Q,
    int Qfreq,
    int pitch,
    int batch
) {
    int dims[2] = {P, Q};
    int freq_layout[2] = {P, Qfreq};
    int real_layout[2] = {P, pitch};

    const int unit_stride = 1;
    const int freq_image_dist = P * Qfreq;
    const int real_image_dist = P * pitch;

    const cufftResult status = cufft.planMany(
        plan, 2, dims,
        freq_layout, unit_stride, freq_image_dist,
        real_layout, unit_stride, real_image_dist,
        CUFFT_C2R, batch
    );
    return status == CUFFT_SUCCESS;
}

// Frees every plan, buffer and stream recorded in `cache`, then resets it to
// its default (empty, not ready) state. The whole device is synchronized
// first so no queued work can still be using the resources being destroyed.
// Safe on a partially built cache: only resources that were recorded as
// created are released.
static void release_cache(PatchFftCache& cache, CufftApi& cufft) {
    cudaDeviceSynchronize();

    for (StreamSlot& slot : cache.slots) {
        if (slot.has_r2c) {
            cufft.destroy(slot.r2c);
        }
        if (slot.has_c2r) {
            cufft.destroy(slot.c2r);
        }
        if (slot.Awork != nullptr) {
            cudaFree(slot.Awork);
        }
        if (slot.stream != nullptr) {
            cudaStreamDestroy(slot.stream);
        }
    }

    if (cache.has_b_r2c) {
        cufft.destroy(cache.b_r2c);
    }
    if (cache.Bwork != nullptr) {
        cudaFree(cache.Bwork);
    }
    if (cache.kernel_stream != nullptr) {
        cudaStreamDestroy(cache.kernel_stream);
    }

    cache = PatchFftCache();
}

// Initialises one StreamSlot for the geometry already stored in `cache`: a
// non-blocking stream, a work buffer for cache.batch_capacity patches, and
// batched R2C / C2R plans bound to that stream. Returns false on the first
// failure; whatever was created is recorded in the slot (non-null stream or
// buffer, has_r2c / has_c2r) so that release_cache can free it.
static bool init_stream_slot(PatchFftCache& cache, CufftApi& cufft, StreamSlot& slot) {
    if (cudaStreamCreateWithFlags(&slot.stream, cudaStreamNonBlocking) != cudaSuccess) {
        return false;
    }

    const size_t bytes = size_t(cache.batch_capacity) * cache.work_bytes;
    if (cudaMalloc(&slot.Awork, bytes) != cudaSuccess) {
        return false;
    }

    if (!make_r2c_plan(cufft, &slot.r2c, cache.P, cache.Q, cache.Qfreq,
                       cache.pitch, cache.batch_capacity)) {
        return false;
    }
    slot.has_r2c = true;

    if (!make_c2r_plan(cufft, &slot.c2r, cache.P, cache.Q, cache.Qfreq,
                       cache.pitch, cache.batch_capacity)) {
        return false;
    }
    slot.has_c2r = true;

    return cufft.setStream(slot.r2c, slot.stream) == CUFFT_SUCCESS &&
           cufft.setStream(slot.c2r, slot.stream) == CUFFT_SUCCESS;
}

// Makes `cache` hold every resource needed to convolve an H x W image with a
// Kh x Kw filter. Returns true immediately if it was already built for
// exactly this shape. Otherwise it frees the old resources (release_cache
// synchronizes the device), recomputes the tiling / FFT geometry, and
// rebuilds. A rebuild always invalidates the cached filter spectrum. On any
// allocation or plan failure the cache is released and false is returned.
// Non-positive dimensions return false without touching the cache.
static bool ensure_cache(
    PatchFftCache& cache,
    CufftApi& cufft,
    int H,
    int W,
    int Kh,
    int Kw
) {
    // With H or W == 0 there are no patches: batch_capacity would be 0 and
    // div_up_int(num_patches, batch_capacity) below would divide by zero.
    if (H <= 0 || W <= 0 || Kh <= 0 || Kw <= 0) {
        return false;
    }

    if (cache.ready &&
        cache.H == H &&
        cache.W == W &&
        cache.Kh == Kh &&
        cache.Kw == Kw) {
        return true;
    }

    release_cache(cache, cufft);

    // Shared failure path: free whatever has been built so far.
    auto fail = [&]() {
        release_cache(cache, cufft);
        return false;
    };

    cache.H = H;
    cache.W = W;
    cache.Kh = Kh;
    cache.Kw = Kw;

    cache.tiles_y = div_up_int(H, PATCH_TILE);
    cache.tiles_x = div_up_int(W, PATCH_TILE);
    cache.num_patches = cache.tiles_y * cache.tiles_x;

    // Per-patch FFT dimensions: a PATCH_TILE output tile needs a
    // PATCH_TILE + K - 1 haloed input patch, padded to a cuFFT-friendly size.
    cache.P = next_fast_fft_size(PATCH_TILE + Kh - 1);
    cache.Q = next_fast_fft_size(PATCH_TILE + Kw - 1);

    cache.Qfreq = cache.Q / 2 + 1;

    // In-place R2C needs real rows padded to 2 * (Q / 2 + 1) floats.
    cache.pitch = 2 * cache.Qfreq;

    cache.work_elems = size_t(cache.P) * cache.pitch;
    cache.freq_elems = size_t(cache.P) * cache.Qfreq;
    cache.work_bytes = cache.work_elems * sizeof(float);

    cache.batch_capacity = min_int(PATCH_BATCH, cache.num_patches);

    // No point creating more slots than there are batches to process.
    const int chunks = div_up_int(cache.num_patches, cache.batch_capacity);
    cache.active_streams = min_int(PATCH_STREAMS, chunks);

    // Filter-spectrum resources: stream, buffer, single-image R2C plan.
    if (cudaStreamCreateWithFlags(&cache.kernel_stream, cudaStreamNonBlocking) != cudaSuccess) {
        return fail();
    }

    if (cudaMalloc(&cache.Bwork, cache.work_bytes) != cudaSuccess) {
        return fail();
    }

    if (!make_r2c_plan(cufft, &cache.b_r2c, cache.P, cache.Q, cache.Qfreq,
                       cache.pitch, 1)) {
        return fail();
    }
    cache.has_b_r2c = true;

    if (cufft.setStream(cache.b_r2c, cache.kernel_stream) != CUFFT_SUCCESS) {
        return fail();
    }

    for (int s = 0; s < cache.active_streams; ++s) {
        if (!init_stream_slot(cache, cufft, cache.slots[s])) {
            return fail();
        }
    }

    cache.cached_B = nullptr;
    cache.kernel_fft_ready = false;
    cache.ready = true;

    return true;
}

// Writes the Kh x Kw filter B (row-major) into the real FFT buffer Bwork
// (rows `pitch` floats apart), storing tap (ky, kx) at the wrapped position
// ((-ky) mod P, (-kx) mod Q). Bwork is expected to be zeroed beforehand
// (prepare_kernel_fft memsets it first). Multiplying spectra with this
// flipped/wrapped filter turns the FFT's circular convolution into
//   out[y][x] = sum_{ky,kx} B[ky][kx] * patch[y + ky][x + kx],
// where each input patch already starts at tile origin minus kernel center.
// Launch: 1-D grid with at least Kh * Kw threads; surplus threads exit.
__global__ void place_kernel_fft_kernel(
    const float* __restrict__ B,
    float* __restrict__ Bwork,
    int Kh,
    int Kw,
    int P,
    int Q,
    int pitch
) {
    const size_t tap = size_t(blockIdx.x) * blockDim.x + threadIdx.x;
    if (tap >= size_t(Kh) * Kw) {
        return;
    }

    const int ky = int(tap / Kw);
    const int kx = int(tap - size_t(ky) * Kw);

    // Negated offsets wrapped into [0, P) x [0, Q).
    const int row = ky ? P - ky : 0;
    const int col = kx ? Q - kx : 0;

    Bwork[size_t(row) * pitch + col] = B[tap];
}

// Gathers the haloed input patches for tiles first_patch ..
// first_patch + actual_batch - 1 into Awork. Each batch entry is work_elems
// floats laid out as P rows of `pitch` floats. Entry (row, col) receives
// A[tile_origin_y - Kh/2 + row][tile_origin_x - Kw/2 + col] when the position
// lies inside the logical (PATCH_TILE+Kh-1) x (PATCH_TILE+Kw-1) patch and
// inside the H x W image; every other element (FFT padding, row pitch
// padding, out-of-image halo) is written as 0.
// Launch: 1-D grid with at least actual_batch * work_elems threads.
__global__ void load_A_patches_kernel(
    const float* __restrict__ A,
    float* __restrict__ Awork,
    int H,
    int W,
    int Kh,
    int Kw,
    int tiles_x,
    int first_patch,
    int actual_batch,
    int P,
    int Q,
    int pitch,
    size_t work_elems
) {
    const size_t gid = size_t(blockIdx.x) * blockDim.x + threadIdx.x;
    if (gid >= size_t(actual_batch) * work_elems) {
        return;
    }

    // Batch entry and (row, col) inside its padded buffer.
    const int slot_patch = int(gid / work_elems);
    const size_t offset = gid - size_t(slot_patch) * work_elems;
    const int row = int(offset / pitch);
    const int col = int(offset - size_t(row) * pitch);

    // Tile coordinates of this patch in the output tile grid.
    const int tile_index = first_patch + slot_patch;
    const int tile_row = tile_index / tiles_x;
    const int tile_col = tile_index - tile_row * tiles_x;

    // Image coordinate of the patch's top-left element.
    const int origin_y = tile_row * PATCH_TILE - Kh / 2;
    const int origin_x = tile_col * PATCH_TILE - Kw / 2;

    const bool in_patch =
        row < PATCH_TILE + Kh - 1 &&
        col < PATCH_TILE + Kw - 1 &&
        col < Q;

    float value = 0.0f;
    if (in_patch) {
        const int y = origin_y + row;
        const int x = origin_x + col;
        // Unsigned compare also rejects negative coordinates.
        if ((unsigned)y < (unsigned)H && (unsigned)x < (unsigned)W) {
            value = A[size_t(y) * W + x];
        }
    }

    Awork[gid] = value;
}

// In-place pointwise complex product Afreq[i] *= Bfreq[i mod freq_elems] over
// actual_batch consecutive spectra of freq_elems values each; the single
// filter spectrum Bfreq is shared by every batch entry.
// Launch: 1-D grid with at least actual_batch * freq_elems threads.
__global__ void multiply_patch_spectra_kernel(
    cufftComplex* __restrict__ Afreq,
    const cufftComplex* __restrict__ Bfreq,
    size_t freq_elems,
    int actual_batch
) {
    const size_t i = size_t(blockIdx.x) * blockDim.x + threadIdx.x;
    if (i >= size_t(actual_batch) * freq_elems) {
        return;
    }

    const cufftComplex lhs = Afreq[i];
    const cufftComplex rhs = Bfreq[i % freq_elems];

    cufftComplex product;
    product.x = lhs.x * rhs.x - lhs.y * rhs.y;
    product.y = lhs.x * rhs.y + lhs.y * rhs.x;

    Afreq[i] = product;
}

// Copies the top-left PATCH_TILE x PATCH_TILE block of each inverse-
// transformed patch in Awork (rows `pitch` floats apart, patches work_elems
// floats apart) into its tile of the H x W output C, multiplied by `scale`.
// solution passes scale = 1 / (P * Q), since cuFFT transforms are
// unnormalized. Tile elements that fall outside the image are not written.
// Launch: 1-D grid with at least actual_batch * PATCH_TILE^2 threads.
__global__ void scatter_patches_kernel(
    const float* __restrict__ Awork,
    float* __restrict__ C,
    int H,
    int W,
    int tiles_x,
    int first_patch,
    int actual_batch,
    int pitch,
    size_t work_elems,
    float scale
) {
    const size_t per_tile = size_t(PATCH_TILE) * PATCH_TILE;

    const size_t gid = size_t(blockIdx.x) * blockDim.x + threadIdx.x;
    if (gid >= size_t(actual_batch) * per_tile) {
        return;
    }

    // Batch entry and (ty, tx) inside the output tile.
    const int slot_patch = int(gid / per_tile);
    const size_t within = gid - size_t(slot_patch) * per_tile;
    const int ty = int(within / PATCH_TILE);
    const int tx = int(within - size_t(ty) * PATCH_TILE);

    const int tile_index = first_patch + slot_patch;
    const int tile_row = tile_index / tiles_x;
    const int tile_col = tile_index - tile_row * tiles_x;

    const int y = tile_row * PATCH_TILE + ty;
    const int x = tile_col * PATCH_TILE + tx;

    const bool inside = (unsigned)y < (unsigned)H && (unsigned)x < (unsigned)W;
    if (!inside) {
        return;
    }

    const size_t src = size_t(slot_patch) * work_elems + size_t(ty) * pitch + tx;
    C[size_t(y) * W + x] = Awork[src] * scale;
}

// Computes the spectrum of filter B into cache.Bwork on cache.kernel_stream
// and blocks the host until it is finished, so the slot streams can read
// Bwork afterwards. With CACHE_KERNEL_FFT enabled, returns immediately if the
// spectrum for the same B pointer is already marked valid (only the pointer
// value is compared; ensure_cache invalidates it on any shape change).
// Returns false if any step reports an error; in that case the cache is left
// marked invalid, so a later call always recomputes.
static bool prepare_kernel_fft(
    PatchFftCache& cache,
    CufftApi& cufft,
    const float* B
) {
#if CACHE_KERNEL_FFT
    if (cache.kernel_fft_ready && cache.cached_B == B) {
        return true;
    }
#endif

    // Bwork is about to be overwritten. Invalidate first so that a failure
    // part-way through can never leave a half-written spectrum marked as the
    // valid spectrum of a previously cached pointer.
    cache.kernel_fft_ready = false;
    cache.cached_B = nullptr;

    if (cudaMemsetAsync(
            cache.Bwork,
            0,
            cache.work_bytes,
            cache.kernel_stream
        ) != cudaSuccess) {
        return false;
    }

    // An empty filter leaves Bwork all zeros (zero spectrum); skip the launch
    // rather than issuing a zero-block grid.
    if (cache.Kh > 0 && cache.Kw > 0) {
        place_kernel_fft_kernel<<<
            grid_1d(size_t(cache.Kh) * cache.Kw),
            256,
            0,
            cache.kernel_stream
        >>>(
            B,
            cache.Bwork,
            cache.Kh,
            cache.Kw,
            cache.P,
            cache.Q,
            cache.pitch
        );
    }

    // In-place forward transform of the single filter image.
    if (cufft.execR2C(
            cache.b_r2c,
            cache.Bwork,
            reinterpret_cast<cufftComplex*>(cache.Bwork)
        ) != CUFFT_SUCCESS) {
        return false;
    }

    // Host-side wait: also surfaces execution errors from the work above.
    if (cudaStreamSynchronize(cache.kernel_stream) != cudaSuccess) {
        return false;
    }

    cache.cached_B = B;
    cache.kernel_fft_ready = true;

    return true;
}

// Fallback used on failure paths: fills the H x W float output C with zero
// bytes via cudaMemset (default stream). The return status is not checked.
static void zero_output(float* C, int H, int W) {
    const size_t count = size_t(H) * W;
    const size_t bytes = count * sizeof(float);
    cudaMemset(C, 0, bytes);
}

// Makes cache.kernel_stream and every active slot stream wait for all work
// already enqueued on the legacy default stream before running anything else.
// Those streams are created with cudaStreamNonBlocking, so without this they
// are not ordered after whatever produced A and B. The legacy default stream
// itself waits for earlier work on blocking streams; work the caller queued
// on its own non-blocking stream is NOT covered. If the event cannot be
// created, recorded or waited on, falls back to a full device synchronize.
static void wait_for_caller_work(PatchFftCache& cache) {
    static cudaEvent_t inputs_ready = nullptr;

    if (inputs_ready == nullptr &&
        cudaEventCreateWithFlags(&inputs_ready, cudaEventDisableTiming) != cudaSuccess) {
        inputs_ready = nullptr;
        cudaDeviceSynchronize();
        return;
    }

    if (cudaEventRecord(inputs_ready, cudaStreamLegacy) != cudaSuccess) {
        cudaDeviceSynchronize();
        return;
    }

    bool ordered = true;
    if (cudaStreamWaitEvent(cache.kernel_stream, inputs_ready, 0) != cudaSuccess) {
        ordered = false;
    }
    for (int s = 0; s < cache.active_streams; ++s) {
        if (cudaStreamWaitEvent(cache.slots[s].stream, inputs_ready, 0) != cudaSuccess) {
            ordered = false;
        }
    }

    if (!ordered) {
        cudaDeviceSynchronize();
    }
}

// Enqueues the whole pipeline for patches [first_patch, first_patch +
// actual_batch) on slot.stream, fully asynchronously:
//   1. gather haloed input patches,
//   2. in-place forward R2C,
//   3. multiply by the filter spectrum in cache.Bwork,
//   4. in-place inverse C2R,
//   5. scatter the scaled tiles into C.
// Returns false if either cuFFT execution call reports an error (kernel
// launches are not individually checked).
static bool run_patch_batch(
    const PatchFftCache& cache,
    CufftApi& cufft,
    StreamSlot& slot,
    const float* A,
    float* C,
    int first_patch,
    int actual_batch,
    float scale
) {
    cufftComplex* const spectra = reinterpret_cast<cufftComplex*>(slot.Awork);

    load_A_patches_kernel<<<
        grid_1d(size_t(actual_batch) * cache.work_elems),
        256,
        0,
        slot.stream
    >>>(
        A,
        slot.Awork,
        cache.H,
        cache.W,
        cache.Kh,
        cache.Kw,
        cache.tiles_x,
        first_patch,
        actual_batch,
        cache.P,
        cache.Q,
        cache.pitch,
        cache.work_elems
    );

    if (cufft.execR2C(slot.r2c, slot.Awork, spectra) != CUFFT_SUCCESS) {
        return false;
    }

    multiply_patch_spectra_kernel<<<
        grid_1d(size_t(actual_batch) * cache.freq_elems),
        256,
        0,
        slot.stream
    >>>(
        spectra,
        reinterpret_cast<const cufftComplex*>(cache.Bwork),
        cache.freq_elems,
        actual_batch
    );

    if (cufft.execC2R(slot.c2r, spectra, slot.Awork) != CUFFT_SUCCESS) {
        return false;
    }

    scatter_patches_kernel<<<
        grid_1d(size_t(actual_batch) * PATCH_TILE * PATCH_TILE),
        256,
        0,
        slot.stream
    >>>(
        slot.Awork,
        C,
        cache.H,
        cache.W,
        cache.tiles_x,
        first_patch,
        actual_batch,
        cache.pitch,
        cache.work_elems,
        scale
    );

    return true;
}

// Note: A, B, C are device pointers
// Computes the H x W output
//   C[y][x] = sum_{ky,kx} B[ky][kx] * A[y + ky - Kh/2][x + kx - Kw/2]
// with A treated as zero outside the image, tile by tile via batched FFTs.
// All resources are cached between calls with the same shape. On any failure
// C is filled with zeros. Returns only after all GPU work has finished.
extern "C" void solution(
    const float* A,
    const float* B,
    float* C,
    size_t H,
    size_t W,
    size_t Kh,
    size_t Kw
) {
    // Empty output: nothing to write. Returning here also keeps ensure_cache
    // from dividing by a zero patch count.
    if (H == 0 || W == 0) {
        return;
    }

    const int h = int(H);
    const int w = int(W);
    const int kh = int(Kh);
    const int kw = int(Kw);

    // Empty filter: every output element is an empty sum.
    if (Kh == 0 || Kw == 0) {
        zero_output(C, h, w);
        return;
    }

    CufftApi& cufft = get_cufft();

    if (!cufft.ok) {
        zero_output(C, h, w);
        return;
    }

    static PatchFftCache cache;

    if (!ensure_cache(cache, cufft, h, w, kh, kw)) {
        zero_output(C, h, w);
        return;
    }

    // Order our non-blocking streams after the producers of A and B.
    wait_for_caller_work(cache);

    if (!prepare_kernel_fft(cache, cufft, B)) {
        zero_output(C, h, w);
        return;
    }

    // cuFFT's forward+inverse pair scales results by P * Q.
    const float scale = 1.0f / float(size_t(cache.P) * cache.Q);

    // Each wave hands one batch to every slot; slots run concurrently.
    const int wave_stride = cache.batch_capacity * cache.active_streams;

    bool ok = true;
    for (int base = 0; ok && base < cache.num_patches; base += wave_stride) {
        for (int s = 0; ok && s < cache.active_streams; ++s) {
            const int first_patch = base + s * cache.batch_capacity;
            if (first_patch >= cache.num_patches) {
                break;  // later slots would start even further past the end
            }

            const int actual_batch = min_int(
                cache.batch_capacity,
                cache.num_patches - first_patch
            );

            ok = run_patch_batch(cache, cufft, cache.slots[s], A, C,
                                 first_patch, actual_batch, scale);
        }
    }

    // Wait for every stream: C must be complete on return, and on the failure
    // path no scatter may still be writing C while it is being zeroed.
    cudaDeviceSynchronize();

    if (!ok) {
        zero_output(C, h, w);
    }
}

Comments