mars-finetuning
Add support for LayerNormalization
Hi!
I would like to add support for LayerNormalization, and I thought about implementing the constraint as follows:
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint


class LayerNormLipschitzConstraint(Constraint):
    def __init__(self, max_k, zero_gamma=None):
        self.max_k = max_k
        self.zero_gamma = zero_gamma

    def __call__(self, w):
        # Measure gamma relative to zero_gamma (e.g. the pretrained values) if given.
        if self.zero_gamma is not None:
            t = w - self.zero_gamma
        else:
            t = w
        # Clip each entry of t elementwise to [-max_k, max_k].
        v = t * (1.0 / K.maximum(1.0, K.abs(t) / self.max_k))
        if self.zero_gamma is not None:
            return self.zero_gamma + v
        else:
            return v

    def get_config(self):
        return {"max_k": self.max_k}
I noticed that BatchNormLipschitzConstraint additionally divides gamma by the standard deviation, but I could not find this step in the paper:
diag = w / K.sqrt(self.variance + 1e-6)
LayerNormalization, however, does not keep a moving variance. Is dividing by the standard deviation necessary here as well?
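
To make the question more concrete, my current understanding is that the division reflects the inference-time behaviour of BatchNormalization, where the effective per-channel scale is gamma / sqrt(moving_variance + eps). A purely illustrative sketch with made-up numbers:

import numpy as np

# BatchNormalization at inference time computes
#   y = gamma * (x - moving_mean) / sqrt(moving_variance + eps) + beta,
# so its effective per-channel scale is gamma / sqrt(moving_variance + eps).
gamma = np.array([1.5, 0.8, 2.0])
moving_variance = np.array([4.0, 0.25, 1.0])
eps = 1e-6

effective_scale = gamma / np.sqrt(moving_variance + eps)
print(effective_scale)  # approximately [0.75, 1.6, 2.0]

# LayerNormalization computes mean/variance per sample at both training and
# inference time, so there is no fixed moving variance that could be folded
# into the constraint in the same way.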
Thank you!