llm.c
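Online softmax is wrong: `softmax_forward_online_kernel8` compares the block-wide thread index `tid` (`threadIdx.x`) against `C`. Because each warp handles one row, whenever `C` is smaller than `blockDim.x` every warp whose threads all have `tid >= C` returns immediately, and the row assigned to that warp is never processed. The check should use the lane index within the warp instead.
You should change: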
__global__ void softmax_forward_online_kernel8(float* out, const float* inp, int N, int C) {
const int warpsPerBlock = blockDim.x / warpSize;
int tid = threadIdx.x;
if (tid >= C) {
return;
}
int warpId = tid / warpSize;
int laneId = tid % warpSize;
// one warp one row
int row = blockIdx.x * warpsPerBlock + warpId;
if (row >= N) {
return;
}
into:
const int warpsPerBlock = blockDim.x / warpSize;
int tid = threadIdx.x;
int warpId = tid / warpSize;
int laneId = tid % warpSize;
// one warp one row
int row = blockIdx.x * warpsPerBlock + warpId;
if (laneId >= C) {
return;
}
if (row >= N) {
return;
}