Fix incorrect GPU assignment in multi-GPU setup
Fix for #369
Checked out this PR - works fine on WSL2 Ubuntu 22.04. Thanks for the fix!
This fix works on my end for the small/default gpt2 model (2.1M tok/s on 8 GPUs).
The current code still breaks for the gpt2-xl model, albeit in a different way, unrelated to OpenMPI (with 1 GPU or 8 GPUs). The error I'm getting is:
1373: void matmul_backward(floatX*, floatX*, floatX*, floatX*, floatX*, floatX*, float*, int, int, int, int): Assertion `(OC % OC_per_warp) == 0' failed.
Commit 6c179fa works fine for gpt2-xl. It is possible that a separate independent bug was introduced since that commit that only breaks the larger model.
> The current code still breaks for the gpt2-xl model, albeit in a different way, unrelated to OpenMPI (with 1 GPU or 8 GPUs). The error I'm getting is:
> 1373: void matmul_backward(floatX*, floatX*, floatX*, floatX*, floatX*, floatX*, float*, int, int, int, int): Assertion `(OC % OC_per_warp) == 0' failed.
Gah, that bug is also my fault; I genuinely didn't think the hidden dimension would fail to be a multiple of 256 for any model we cared about! GPT2-XL is very much the exception here: 25 heads at a head size of 64 gives a hidden dimension of 1600, which is not a multiple of 256, while every single other GPT2 and GPT3 config (including GPT3-XL) meets that condition... 25 heads is such a weird design choice!
It's trivial to add a bounds check in the loop of `matmul_backward_bias_kernel7`, but that would hurt performance... I think you can't just exit early for the out-of-bounds threads because of the reordering code at the bottom, and you need the `__syncthreads()` to run for every thread to avoid deadlock, so...
I don't have a machine to test anything on today unfortunately, but hopefully this (combined with commenting out that assert) would work:
```cuda
for (int k = 0; k < x128::size; k++) {
    accumulators[k] = 0.0f;
}
int thread_id = threadIdx.y * block_size_x + threadIdx.x;
for (int idx = thread_id; idx < OC_per_warp; idx += block_size) {
    shared[idx] = 0.0f;
}
__syncthreads();
if (global_oc < OC) { ///////////////// NEW LINE
    for (int idx = blockIdx.y*block_size_y + threadIdx.y; idx < B * T; idx += gridDim.y*block_size_y) {
        x128 packed_dout = load128(dout + global_oc + idx*OC);
        for (int k = 0; k < x128::size; k++) {
            accumulators[k] += (float)packed_dout[k];
        }
    }
    // we need to avoid shared memory bank conflicts for the atomicAdd to maximise performance
    // so we accumulate in a conflict-free order, then reorder to match the global memory order
    for (int k = 0; k < x128::size; k++) {
        atomicAdd(shared + threadIdx.x + (k * block_size_x), accumulators[k]);
    }
} ///////////////// NEW LINE
if (threadIdx.y >= x128::size) { return; } // only need this many warps to reorder the data
__syncthreads();
// read the accumulated values in the conflict-free order
int i = threadIdx.x + (threadIdx.y * block_size_x);
float tmp = shared[i];
__syncthreads();
// write them back to shared memory in the global memory order
// 8-way bank conflict for BF16 x128, but only 8x per threadblock (rather than 8x per warp)
shared[local_oc + threadIdx.y] = tmp;
__syncthreads();
// now we do a perfectly coalesced atomic add to global memory (1x 128-byte cacheline per warp)
if (i + blockIdx.x*OC_per_warp < OC) { ///////////////// NEW LINE
    atomicAdd(dbias + i + blockIdx.x*OC_per_warp, shared[i]);
} ///////////////// NEW LINE
```
You'd also need to replace

```cuda
const int grid_size_x = OC / OC_per_warp; // e.g. 3 horizontal blocks for 768 OCs at BF16
```

with:

```cuda
const int grid_size_x = CEIL_DIV(OC, OC_per_warp); // e.g. 3 horizontal blocks for 768 OCs at BF16
```
I'll have time to try this on my side over the weekend if no one beats me to making a PR for it (feel free!)
Confirmed this fixes the issue, stepping at ~460K tok/s on 4x A100 GPUs.
I've verified that the fix by ademeure works on my end for GPT2-XL and opened a PR.