FLASHQuad_pytorch icon indicating copy to clipboard operation
FLASHQuad_pytorch copied to clipboard

关于GAU单元的问题

Open singaln opened this issue 2 years ago • 2 comments

您好,看了您关于GAU的代码,发现您的代码中并没有scale_offset的相关代码。 `def scale_offset(x): gamma = var(x.shape[−1:]) beta = var(x.shape[−1:]) return x ∗ gamma + beta

def attn(x, v, s=128): z = dense(x, s) q, k = scale_offset(z), scale_offset(z) qk = tf.einsum('bns,bms→bnm', q, k) a = relu(qk + rel_pos_bias(q, k)) ∗∗ 2 return tf.einsum('bnm,bme→bne', a, v)

def gated_attn_unit(x, d=768, e=1536): shortcut, x = x, norm(x) u, v = dense(x, e), dense(x, e) x = u ∗ attn(x, v) return dense(x, d) + shortcut`

singaln avatar Mar 15 '22 02:03 singaln

我是按照code6实现的这一部分,然后我感觉底下这部分代码就代表了scale offset,只不过是同时计算了 base = torch.einsum("...r,hr->...hr", base, self.weight) + self.bias

JunnYu avatar Mar 15 '22 03:03 JunnYu

我是按照code6实现的这一部分,然后我感觉底下这部分代码就代表了scale offset,只不过是同时计算了 base = torch.einsum("...r,hr->...hr", base, self.weight) + self.bias

OK,明白了!!!

还有一个小问题是您在flash进行预训练的时候,并没有在代码中看到Mixed chunk Attention分块混合的部分,而是直接采用GAU来进行训练的,这样对于预训练结果是不是会有一些影响呢?

singaln avatar Mar 16 '22 03:03 singaln