
Add support for LayerNormalization

Open · Daniel300000 opened this issue 1 year ago · 0 comments

Hi!

I would like to add support for LayerNormalization. To do this, I thought about implementing the constraint as follows:

from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint


class LayerNormLipschitzConstraint(Constraint):
    def __init__(self, max_k, zero_gamma=None):
        self.max_k = max_k
        self.zero_gamma = zero_gamma

    def __call__(self, w):
        # Optionally constrain the deviation from a reference gamma
        # (zero_gamma) rather than the raw weight itself.
        if self.zero_gamma is not None:
            t = w - self.zero_gamma
        else:
            t = w

        # Rescale so that each entry's absolute value is at most max_k.
        v = t * (1.0 / K.maximum(1.0, K.abs(t) / self.max_k))

        if self.zero_gamma is not None:
            return self.zero_gamma + v
        else:
            return v

    def get_config(self):
        return {"max_k": self.max_k}
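
For reference, this is how I would plug the constraint into a layer (just a minimal sketch; max_k=1.0 is a placeholder value):

import tensorflow as tf

# Hypothetical usage: bound the learnable scale (gamma) of a
# LayerNormalization layer via Keras' gamma_constraint argument.
layer_norm = tf.keras.layers.LayerNormalization(
    gamma_constraint=LayerNormLipschitzConstraint(max_k=1.0)
)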

I saw that BatchNormLipschitzConstraint also divides gamma by the standard deviation. However, I did not find this in the paper:

diag = w / K.sqrt(self.variance + 1e-6)

Also, LayerNormalization does not have a moving variance. Is the standard deviation necessary here?
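
To illustrate my understanding of where the division comes from (this is my own reasoning, not something I took from the repository or the paper): at inference time BatchNormalization computes y = gamma * (x - moving_mean) / sqrt(moving_variance + eps) + beta, so the effective per-channel scale a Lipschitz bound would act on is gamma / sqrt(moving_variance + eps). LayerNormalization recomputes mean and variance per sample in every forward pass, so the only fixed learnable scale seems to be gamma itself:

import numpy as np

# Sketch (my understanding; values are made up): effective per-feature
# scale of the two normalization layers.
gamma = np.array([0.5, 2.0, 1.0])
moving_variance = np.array([0.25, 4.0, 1.0])
eps = 1e-6

# BatchNorm at inference uses stored statistics, so the diagonal scale
# that the constraint would need to bound is gamma / sqrt(var + eps).
bn_effective_scale = gamma / np.sqrt(moving_variance + eps)

# LayerNorm has no moving variance; gamma is the only fixed scale.
ln_effective_scale = gamma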

Thank you!

Daniel300000 · May 25, 2023