lightseq
Layer normalization kernel
Why is the grid_dim argument used to launch the layer-norm kernel ker_layer_norm only batch_size? That does not match the comment; the size in the comment, gridDim(batch_size * seq_len), looks like the correct one.
template <typename T>
__forceinline__ __device__ T add_eps(T x) {
  return fabsf(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON);
}

/**
@brief: ker_layer_norm
Standard layer normalization.
It not only outputs the layer norm result,
but also the variance.
It may also output the means, depending on whether
the means argument is nullptr.

@thread
gridDim.x = batch_size * seq_len
blockDim.x = hidden_size

@param
ln_res: [batch_size * seq_len, hidden_size], ln result.
vars: [batch_size * seq_len], variance per token
means: [batch_size * seq_len], means per token, can be nullptr
inp: [batch_size * seq_len, hidden_size], ln input.
scale: [hidden_size], ln scale
bias: [hidden_size], ln bias
*/
template <typename T>
__global__ void ker_layer_norm(T *ln_res, T *vars, T *means, const T *inp,
                               const T *scale, const T *bias, int hidden_size) {
  // step 0. compute local sum
  float l_sum = 0;
  float l_square_sum = 0;
  const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size;
  for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float4 val = inp_f4[idx];
    l_sum += val.x + val.y + val.z + val.w;
    l_square_sum +=
        val.x * val.x + val.y * val.y + val.z * val.z + val.w * val.w;
  }

  // step 1. compute reduce sum
  float mean_dim = float(hidden_size) * 4.f;
  float reduce_val[2] = {l_sum, l_square_sum};
  blockReduce<ReduceType::kSum, 2>(reduce_val);
  __shared__ float s_mean, s_var;
  if (threadIdx.x == 0) {
    s_mean = reduce_val[0] / mean_dim;
    if (means != nullptr) {
      means[blockIdx.x] = s_mean;
    }
    s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON;
    vars[blockIdx.x] = s_var;
    s_var = rsqrtf(s_var);
  }
  __syncthreads();

  // step 2. layer norm result
  float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size;
  for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float4 vscale = __ldg((const float4 *)scale + idx);
    float4 vbias = __ldg((const float4 *)bias + idx);
    float4 val = inp_f4[idx];
    val.x = (val.x - s_mean) * s_var * vscale.x + vbias.x;
    val.y = (val.y - s_mean) * s_var * vscale.y + vbias.y;
    val.z = (val.z - s_mean) * s_var * vscale.z + vbias.z;
    val.w = (val.w - s_mean) * s_var * vscale.w + vbias.w;
    output_f4[idx] = val;
  }
}
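A note on step 1: it relies on the identity Var[x] = E[x^2] - (E[x])^2, which is why the kernel reduces the pair (sum, sum of squares) in a single block-wide reduction. The blockReduce helper itself is not shown in this excerpt; below is a minimal sketch of what a block-wide sum over two per-thread values could look like (warp shuffles plus shared memory). The names and structure here are assumptions for illustration, not lightseq's actual helper.

#include <cuda_runtime.h>

constexpr int WARP_SIZE = 32;

// Sum NUM values across the lanes of one warp using shuffle-xor.
template <int NUM>
__device__ __forceinline__ void warp_reduce_sum(float *val) {
  for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
    for (int i = 0; i < NUM; ++i) {
      val[i] += __shfl_xor_sync(0xffffffff, val[i], offset);
    }
  }
}

// Sum NUM values across the whole block; thread 0 ends up with the totals.
template <int NUM>
__device__ __forceinline__ void block_reduce_sum(float *val) {
  __shared__ float smem[NUM][32];  // one slot per warp, at most 32 warps
  int lane = threadIdx.x & (WARP_SIZE - 1);
  int wid = threadIdx.x >> 5;

  warp_reduce_sum<NUM>(val);  // reduce within each warp
  if (lane == 0) {
    for (int i = 0; i < NUM; ++i) smem[i][wid] = val[i];
  }
  __syncthreads();

  // the first warp reduces the per-warp partial sums
  int num_warps = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;
  for (int i = 0; i < NUM; ++i) {
    val[i] = (threadIdx.x < num_warps) ? smem[i][lane] : 0.f;
  }
  if (wid == 0) warp_reduce_sum<NUM>(val);
}

Invoked as block_reduce_sum<2>(reduce_val), thread 0 would hold both block-wide sums afterwards, which matches how the kernel above only reads reduce_val when threadIdx.x == 0.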
template <>
void launch_layer_norm<float>(float *ln_res, float *vars, float *means,
                              const float *inp, const float *scale,
                              const float *bias, int batch_size, int hidden_dim,
                              cudaStream_t stream) {
  if (hidden_dim % 4 != 0) {
    throw std::runtime_error("violate hidden_dim % 4 = 0");
  }
  // each thread loads float4 vectors, so the kernel sees hidden_dim / 4 elements
  hidden_dim >>= 2;
  int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS);
  dim3 grid_dim(batch_size);
  dim3 block_dim(nthread);

  ker_layer_norm<float><<<grid_dim, block_dim, 0, stream>>>(
      ln_res, vars, means, inp, scale, bias, hidden_dim);
}
The size in the comment is correct; the variable name is just a bit confusing. The batch_size argument passed to launch_layer_norm is actually the total token count (batch_size * seq_len), so the launcher does start one block per token, exactly as the comment says.
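For context, a hypothetical call site (pointer names and shapes are assumed here, not copied from the repo) would look like this, with the total token count passed where the parameter is named batch_size:

#include <cuda_runtime.h>

// Illustration only: the launcher's batch_size parameter receives
// batch_size * seq_len, so gridDim.x equals the number of tokens.
void run_layer_norm_example(const float *d_inp, const float *d_scale,
                            const float *d_bias, float *d_ln_res,
                            float *d_vars, float *d_means,
                            cudaStream_t stream) {
  int batch_size = 8, seq_len = 128, hidden_size = 1024;  // example shapes
  int batch_token_num = batch_size * seq_len;             // one block per token

  launch_layer_norm<float>(d_ln_res, d_vars, d_means, d_inp, d_scale, d_bias,
                           batch_token_num, hidden_size, stream);
}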