llm.c icon indicating copy to clipboard operation
llm.c copied to clipboard

7-8% speedup: optimize matmul_backward_bias_kernel, reduce cast ops, improve loop unrolling, direct var use

Open bgorlick opened this issue 8 months ago • 2 comments

optimization in matmul_backward_bias_kernel - measurable 7-8% speedup in the kernel in isolated benchmarking.

summary

optimizes the matmul_backward_bias_kernel, resulting in an average speedup of 7-8%.

changes include:

reduced cast operations
improved loop unrolling
direct variable usage
fully profiled and gpt2_cu tested
benchmarked the algorithms independently of gpt2 over million+ iter to confirm performance
tensor max diff: 6.426e-08, rel error: 1.215e-05

profiling summary

single gpu (nvidia rtx a6000)
batch_size: 4, seq_len: 64, zero optimization disabled
logit max diff: 0.000847
tensor checks: max diff: 1.252e-05, rel error: 8.960e-06
step-by-step loss improvements confirmed, consistent execution times

benchmark script

the benchmark.cu script used for profiling and verifying performance improvements is included below for reproducibility and further testing.

it was tricky to properly fully benchmark the kernel, even after profiling with nsys because the kernel is part of the larger project so this benchmark script compares the original and submission in isolation. While this is not a true full representation of the kernel's performance in the larger project, it does give a good idea of the performance improvements.


/*
 * Benchmark Script for matmul_backward_bias_kernel Optimization
 * For gpt2_cu in llm.c project found here: https://github.com/karpathy/llm.c
 *
 * This script benchmarks the optimized and original versions of the matmul_backward_bias_kernel
 * to verify the performance improvements. 

 * It is tricky to fully benchmark the kernel because its a part of a larger project so
 * I had to create a separate benchmark script to test the kernel in isolation. While this is not a true
 * representation of the kernel's performance in the larger project, it does give a good idea of the
 * performance improvements.

 * I performed benchmarks using an A6000 so the results may vary on other GPUs.
 * 
 * Key Features:
 * - Fully profiled and tested with gpt2_cu
 * - Measures performance improvements and ensures correctness of results
 * 
 * Usage:
 * Compile with:
 * nvcc -o bench bench.cu -arch=sm_86 -O3    //  use the appropriate sm_xx architecture for your GPU 
 * Run with:
 * ./bench <kernel_version> -b <iterations>
 * 
 * kernel_version: 0 for the optimized kernel, 1 for the original kernel
 * 
 * Copyright (c) 2024 Benjamin Gorlick
 * Licensed under the MIT License
 * github.com/bgorlick/ - Benchmark part of llm.c project
 */


#include <cuda_runtime.h>
#include <iostream>
#include <chrono>
#include <assert.h>
#include "llmc/cuda_utils.cuh"  //  this includes definitions for x128, floatX, and any other required types/macros

#ifndef WARP_SIZE
#define WARP_SIZE 32
#endif

/// workin on this one
template<typename OutFloat, bool UseAuxBuffer>
__global__ void matmul_backward_bias_kernel9(OutFloat* dbias, const floatX* dout, int B, int T, int OC,
                                                std::bool_constant<UseAuxBuffer>) {

    static_assert(WARP_SIZE % 4 == 0, "WARP_SIZE must be a multiple of 4");
    constexpr const int bdx = 4;
    constexpr const int bdy = WARP_SIZE / bdx;
    assert(blockDim.x == bdx && blockDim.y == bdy);

    int warp_d = threadIdx.x, warp_c = threadIdx.y, block_d = threadIdx.z;
    int global_oc = blockIdx.x * bdy * x128::size + warp_c * x128::size; // 64 OCs at BF16

    if (global_oc < OC) {
        int bt_per_block = bdx * blockDim.z;
        float accumulators[x128::size] = {0};
        int idx = blockIdx.y * bt_per_block + warp_d + bdx * block_d;
        x128 packed_dout;
        float* packed_dout_float = reinterpret_cast<float*>(&packed_dout);

        // sum up over all bt within registers
        #pragma unroll
        for (; idx < B * T; idx += gridDim.y * bt_per_block) {
            packed_dout = load128(dout + global_oc + idx * OC);
            
            #pragma unroll
            for (int k = 0; k < x128::size; k++) {
                accumulators[k] += packed_dout_float[k];
            }
        }
        __shared__ float sub_results[x128::size][WARP_SIZE][bdy];

        #pragma unroll
        for (int k = 0; k < x128::size; k++) {
            float v = accumulators[k];
            v += __shfl_down_sync(0xffffffff, v, 1, 4);
            v += __shfl_down_sync(0xffffffff, v, 2, 4);
            if (warp_d == 0) {
                sub_results[k][block_d][warp_c] = v;
            }
        }
        __syncthreads();

        #pragma unroll 8
        for (int k = block_d; k < x128::size; k += blockDim.z) {
            float a = 0.f;

            #pragma unroll
            for (int r = warp_d; r < blockDim.z; r += bdx) {
                float v = sub_results[k][r][warp_c];
                v += __shfl_down_sync(0xffffffff, v, 1, 4);
                v += __shfl_down_sync(0xffffffff, v, 2, 4);
                a += v;
            }
            
            if (warp_d == 0) {
                if constexpr (!UseAuxBuffer) {
                    dbias[global_oc + k] = (OutFloat)(a + (float)dbias[global_oc + k]);
                } else {
                    dbias[global_oc + k + blockIdx.y * OC] = a;
                }
            }
        }
    }
}
                   
// reference (original code)
template<typename OutFloat, bool UseAuxBuffer>
__global__ void matmul_backward_bias_kernel_v2(OutFloat* dbias, const floatX* dout, int B, int T, int OC,
                                               std::bool_constant<UseAuxBuffer>) {
    constexpr const int bdx = 4;
    constexpr const int bdy = WARP_SIZE / bdx;
    assert(blockDim.x == bdx);
    assert(blockDim.y == bdy);

    int warp_d = (int)threadIdx.x;
    int warp_c = (int)threadIdx.y;
    int block_d = (int)threadIdx.z;

    const int OC_per_warp = bdy * x128::size;

    int local_oc = warp_c * x128::size;
    int global_oc = blockIdx.x * OC_per_warp + local_oc;

    int local_bt = warp_d + bdx * block_d;
    int bt_per_block = bdx * blockDim.z;

    float accumulators[x128::size];
    for (int k = 0; k < x128::size; k++) {
        accumulators[k] = 0.0f;
    }

    if(global_oc < OC) {
        for (int idx = blockIdx.y * bt_per_block + local_bt; idx < B * T; idx += gridDim.y * bt_per_block) {
            x128 packed_dout = load128(dout + global_oc + idx*OC);
            for (int k = 0; k < x128::size; k++) {
                accumulators[k] += (float)packed_dout[k];
            }
        }
    }

    __shared__ float sub_results[x128::size][WARP_SIZE][bdy];

    for (int k = 0; k < x128::size; k++) {
        float v = accumulators[k];
        v += __shfl_down_sync(0xffffffff, v, 1, 4);
        v += __shfl_down_sync(0xffffffff, v, 2, 4);
        if(warp_d == 0) {
            sub_results[k][block_d][warp_c] = v;
        }
    }
    __syncthreads();

    for (int k = block_d; k < x128::size; k += blockDim.z) {
        float a = 0.f;
        for (int r = warp_d; r < blockDim.z; r += bdx) {
            float v = sub_results[k][r][warp_c];
            v += __shfl_down_sync(0xffffffff, v, 1, 4);
            v += __shfl_down_sync(0xffffffff, v, 2, 4);
            a += v;
        }
        if(warp_d == 0 && global_oc < OC) {
            if constexpr (!UseAuxBuffer) {
                dbias[global_oc + k] = (OutFloat)(a + (float)dbias[global_oc + k]);
            } else {
                dbias[global_oc + k + blockIdx.y * OC] = a;
            }
        }
    }
}

void run_benchmark_v1(int iterations, int B, int T, int OC, floatX* dout, float* dbias) {
    for (int i = 0; i < iterations; ++i) {
        matmul_backward_bias_kernel9<<<dim3((OC + 63) / 64, (B * T + 15) / 16), dim3(4, 8, 4)>>>(dbias, dout, B, T, OC, std::false_type{});
        cudaDeviceSynchronize();
    }
}

void run_benchmark_v2(int iterations, int B, int T, int OC, floatX* dout, float* dbias) {
    for (int i = 0; i < iterations; ++i) {
        matmul_backward_bias_kernel_v2<<<dim3((OC + 63) / 64, (B * T + 15) / 16), dim3(4, 8, 4)>>>(dbias, dout, B, T, OC, std::false_type{});
        cudaDeviceSynchronize();
    }
}

int main(int argc, char* argv[]) {
    if (argc < 4) {
        std::cerr << "Usage: " << argv[0] << " <kernel_version> -b <iterations>\n";
        return 1;
    }

    int kernel_version = std::atoi(argv[1]);
    int iterations = std::atoi(argv[3]);

    int B = 256;  
    int T = 128;  
    int OC = 1024;  

    floatX* dout;
    float* dbias;
    cudaMalloc(&dout, B * T * OC * sizeof(floatX));
    cudaMalloc(&dbias, OC * sizeof(float));

    auto start = std::chrono::high_resolution_clock::now();

    if (kernel_version == 0) {
        run_benchmark_v1(iterations, B, T, OC, dout, dbias);
    } else if (kernel_version == 1) {
        run_benchmark_v2(iterations, B, T, OC, dout, dbias);
    } else {
        std::cerr << "Invalid kernel version: " << kernel_version << "\n";
        return 1;
    }

    auto end = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> elapsed = end - start;

    std::cout << "Elapsed time: " << elapsed.count() << " seconds\n";

    cudaFree(dout);
    cudaFree(dbias);

    return 0;
}

bgorlick avatar Jun 19 '24 12:06 bgorlick