Question about computational improvement of LayerNorm2d

Open 072jiajia opened this issue 1 year ago • 0 comments

Thank you for your excellent work.

I would like to ask whether the following change could speed up the model's computation. LayerNorm2d computes [(x - u) / s] * w + b. If the /s and *w steps are combined into a single multiplication by (w / s), the normalization needs only 3 elementwise operations on tensors of x.shape instead of 4, which should reduce both the run time and the memory used during the computation.

The following is my modification:

import torch
from torch import nn


class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize over the channel dimension (dim 1) of a (B, C, H, W) input.
        mean = x.mean(1, keepdim=True)
        x = x - mean
        variance = x.pow(2).mean(1, keepdim=True)
        # Fold the per-channel weight and 1 / std into a single factor.
        scaling = self.weight[:, None, None] / torch.sqrt(variance + self.eps)
        x = scaling * x + self.bias[:, None, None]
        return x
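
To check that the modification does not change the output, the two formulations can be compared on random input. The following is a minimal sketch: `LayerNorm2dReference` is my assumption of what the existing layer does, written directly from the [(x - u) / s] * w + b formula above, and the class names, helper names, and tensor sizes are all hypothetical.

```python
import torch
from torch import nn


class LayerNorm2dReference(nn.Module):
    """Assumed baseline: applies [(x - u) / s] * w + b in that order."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        u = x.mean(1, keepdim=True)
        s = torch.sqrt((x - u).pow(2).mean(1, keepdim=True) + self.eps)
        x = (x - u) / s
        return self.weight[:, None, None] * x + self.bias[:, None, None]


class LayerNorm2dFused(LayerNorm2dReference):
    """The modification from this issue: multiply once by (w / s)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x - x.mean(1, keepdim=True)
        variance = x.pow(2).mean(1, keepdim=True)
        scaling = self.weight[:, None, None] / torch.sqrt(variance + self.eps)
        return scaling * x + self.bias[:, None, None]


def max_abs_difference(num_channels: int = 8, seed: int = 0) -> float:
    """Largest elementwise gap between the two layers on random data."""
    torch.manual_seed(seed)
    ref = LayerNorm2dReference(num_channels)
    fused = LayerNorm2dFused(num_channels)
    # Use non-trivial parameters so weight and bias actually matter.
    weight = torch.randn(num_channels)
    bias = torch.randn(num_channels)
    x = torch.randn(2, num_channels, 16, 16)
    with torch.no_grad():
        for layer in (ref, fused):
            layer.weight.copy_(weight)
            layer.bias.copy_(bias)
        return (ref(x) - fused(x)).abs().max().item()


def scaling_shape(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> tuple:
    """Shape of the fused (w / s) factor for a given input."""
    centered = x - x.mean(1, keepdim=True)
    variance = centered.pow(2).mean(1, keepdim=True)
    return tuple((weight[:, None, None] / torch.sqrt(variance + eps)).shape)


if __name__ == "__main__":
    print("max abs difference:", max_abs_difference())
    print("scaling shape:", scaling_shape(torch.randn(2, 8, 16, 16), torch.ones(8)))
```

The difference should be at the level of floating-point rounding. Note that the weight is per channel while the standard deviation is per pixel, so the (w / s) factor broadcasts to the full shape of x. It is therefore worth timing both versions on the target hardware before relying on the saving.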

072jiajia, Apr 11 '23 17:04