Question about computational improvement of LayerNorm2d
Thank you for your excellent work.
I would like to ask whether a small modification might speed up your model's computation.

`LayerNorm2d` performs the operation `[(x - u) / s] * w + b`. If the division by `s` and the multiplication by `w` are folded into a single multiplication by `(w / s)`, the operation passes over a tensor of shape `x.shape` only 3 times instead of 4, which reduces both the time and the memory used during the computation.
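For reference, my reading of the current `forward` is that the division by `s` and the multiplication by the weight are two separate full-size element-wise operations, roughly like this (inside `LayerNorm2d`):

```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
    u = x.mean(1, keepdim=True)
    s = (x - u).pow(2).mean(1, keepdim=True)
    # The divide and the weight multiply each pass over every element of x.
    x = (x - u) / torch.sqrt(s + self.eps)
    x = self.weight[:, None, None] * x + self.bias[:, None, None]
    return x
```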
The following is my modification:

```python
import torch
import torch.nn as nn


class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize over the channel dimension (dim 1) of an NCHW tensor.
        mean = x.mean(1, keepdim=True)
        x = x - mean
        variance = x.pow(2).mean(1, keepdim=True)
        # Fold the per-channel weight into the normalization factor,
        # so the affine step needs only one multiply and one add.
        scaling = self.weight[:, None, None] / torch.sqrt(variance + self.eps)
        x = scaling * x + self.bias[:, None, None]
        return x
```
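To confirm that the fused version is numerically equivalent to a standard channel-wise layer norm, a quick sanity check (the shapes and tolerance below are just placeholders) can compare it against `torch.nn.functional.layer_norm` applied over the channel dimension:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
ln = LayerNorm2d(num_channels=64)
x = torch.randn(2, 64, 32, 32)

out_fused = ln(x)

# Reference: move channels last, normalize over them with F.layer_norm, move back.
ref = F.layer_norm(x.permute(0, 2, 3, 1), (64,), ln.weight, ln.bias, ln.eps)
ref = ref.permute(0, 3, 1, 2)

print(torch.allclose(out_fused, ref, atol=1e-5))  # expect True
```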