[BUG] TSMixer TimeBatchNorm2d using BatchNorm instead of LayerNorm
Describe the bug
The TSMixer implementation uses LayerNorm overall and doesn't offer a plain BatchNorm option. The paper briefly mentions that LayerNorm gave better results on smaller batches. However, the TimeBatchNorm2d implementation inherits from BatchNorm2d, which in my tests has been unstable, agreeing with the authors. Inheriting from nn.InstanceNorm2d in that class works out of the box, produces stable training, and is a better representation of the paper.
To Reproduce
Train a model with TimeBatchNorm2d normalization and a lower batch size on non-stationary targets: the model doesn't train (in my case at least).
Replace the base class with nn.InstanceNorm2d, as sketched below, and the model trains just fine.
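A minimal sketch of the replacement I mean, assuming the layer sees (batch, time, features) inputs and normalizes them through a single artificial channel (the exact reshaping in the darts implementation may differ):

import torch.nn as nn

# Sketch (assumption): normalize each sample over (time, features) on its own,
# instead of sharing statistics across the whole batch like BatchNorm2d does.
class TimeInstanceNorm2d(nn.InstanceNorm2d):
    def __init__(self):
        # a single artificial channel
        super().__init__(num_features=1)

    def forward(self, x):
        # (batch, time, features) -> (batch, 1, time, features) -> normalize -> back
        return super().forward(x.unsqueeze(1)).squeeze(1)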
Expected behavior
The model trains at least similarly to the LayerNorm option.
System (please complete the following information):
- Python version: [e.g. 3.10]
- darts version: [e.g. 0.29.0]
Additional context
I think it is best to either add another option using InstanceNorm2d or replace the current implementation altogether. Since BatchNorm isn't even an option to choose from currently, it is reasonable to choose the latter.
Note: InstanceNorm and LayerNorm are slightly different, but very similar in practical terms when compared to BatchNorm (see the illustration below the links).
https://discuss.pytorch.org/t/is-there-a-layer-normalization-for-conv2d/7595/6
from the author: https://github.com/ditschuk/pytorch-tsmixer/issues/2
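To illustrate the practical difference, a minimal self-contained check (my own example, not darts code): a sample's BatchNorm output depends on the rest of the batch, while its InstanceNorm output does not.

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.rand(4, 1, 5, 100)  # (batch, channel, time, features)

bn = nn.BatchNorm2d(1, affine=False).train()  # uses batch statistics
inorm = nn.InstanceNorm2d(1, affine=False)    # uses per-sample statistics

# Under BatchNorm, the first sample's output changes with the rest of the batch:
print(torch.allclose(bn(x)[0], bn(x[:1])[0]))        # False
# Under InstanceNorm (and LayerNorm over the last two dims), it does not:
print(torch.allclose(inorm(x)[0], inorm(x[:1])[0]))  # True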
Attaching some quick comparisons on my datasets.
Green is LayerNorm, orange is BatchNorm2d, grey is InstanceNorm2d.
[figure: val_loss curves for the three normalization variants]
This is the comparison using the tsmixer example notebook from your repo, with full_training enabled: default vs batch2d vs instance2d. Performance is very similar, at least indicating that there is no degradation in performance for other datasets. As the authors point out, BatchNorm is probably fine for stationary datasets but isn't necessarily stable in general.
Hi @tRosenflanz, and thanks for raising this issue and for the investigation 🚀
Quick questions:
- are the results you show from a use case where you applied the zeroing of past target and future covariates as mentioned in #2381?
- did you use use_reversible_instance_norm=True in your use case (this applies Reversible InstanceNorm to the input and output; see the snippet after this list)? I also observed some differences between norms, but from my experiments I couldn't rule out batch norm, since it indeed performed well on some datasets (as in the ETTh example)
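For reference, a sketch of how that flag is enabled; the chunk lengths here are placeholder values, not a recommendation:

from darts.models import TSMixerModel

# sketch: apply Reversible InstanceNorm to the target series
model = TSMixerModel(
    input_chunk_length=24,   # placeholder
    output_chunk_length=12,  # placeholder
    use_reversible_instance_norm=True,
)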
And a question about this :)
I think it is best to either add another option using InstanceNorm2d or replace the current implementation altogether.
- what do you mean by replacing the current implementation altogether?
- I am indeed using it with 0ing, but without the reversible instance norm flag. I will repeat this experiment without the 0ing later today/tomorrow to make sure that's not the issue. I don't think it should make any difference, since my past_covariates have several features and it is all concatenated together.
- After testing with ETTh1/2, I don't think replacing is necessary, so you can ignore that comment. I initially thought BatchNorm would fail harder, but it clearly holds up fine on some datasets. TimeInstanceNorm2d could simply be another accepted value for norm_type, along the lines of the sketch below.
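Roughly what I have in mind, as a hypothetical dispatch (not the actual darts internals), reusing the TimeInstanceNorm2d class sketched earlier:

import torch.nn as nn

# hypothetical sketch of extending the norm_type selection
def build_norm(norm_type: str, normalized_shape):
    if norm_type == "TimeInstanceNorm2d":
        return TimeInstanceNorm2d()  # proposed addition (class sketched earlier)
    # ... the existing "TimeBatchNorm2d" branch would stay as it is today ...
    return nn.LayerNorm(normalized_shape)  # existing default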
Thanks for the additional info.
- use_reversible_instance_norm actually applies (reversible) InstanceNorm to the target values. I assume that it could also improve the results/stability of batch norm for your use case.
I do agree that adding InstanceNorm as a norm_type option would make sense. Is this something you'd like to contribute? :)
Retested without 0ing, with identical results. I even tried RINorm + TimeBatchNorm2d together and it still isn't working; my guess is that the other covariates are mangled too much.
RINorm looks interesting, but unfortunately I can't use it with my datasets since the target itself is a "leak", hence the search for other normalizations.
I will make a PR to add it; it seems fairly trivial.
I have arrived at the conclusion that this is totally unnecessary: the LayerNormNoBias in the current implementation is in fact identical to InstanceNorm2d. LayerNorm normalizes over the last two dimensions together, and InstanceNorm does so per channel, but since the channel here is just an artificial single one, both compute the same mean and variance over the same (time, features) slice of each sample and therefore produce identical outputs. This would explain the experimental similarity. It's been a while since I dug into these norm layers, so maybe I am missing something that can be done to make it useful - please let me know your thoughts.
import torch
import torch.nn as nn

shape = (32, 5, 100)  # (batch, time, features)
inp = torch.rand(shape)
i2d = TimeInstanceNorm2d()  # the InstanceNorm2d-over-one-channel class from above
l = nn.LayerNorm((shape[-2], shape[-1]), elementwise_affine=False)
i2d_out, l_out = i2d(inp), l(inp)
is_close = torch.isclose(i2d_out, l_out, atol=1e-6)
is_close.all()
# tensor(True)
Thank you @tRosenflanz for the PR and for experimenting with the normalization functions. I agree with your observation that, in this context, LayerNormNoBias and InstanceNorm2d are equivalent.
Closing this issue; anyone can feel free to reopen it if they have a counterargument.