
[BUG] TSMixer TimeBatchNorm2d using BatchNorm instead of LayerNorm

tRosenflanz opened this issue 1 year ago • 6 comments

Describe the bug The TSMixer implementation overall uses LayerNorm and doesn't offer a BatchNorm option. The paper briefly mentions that LayerNorm gave better results on smaller batches. However, the TimeBatchNorm2d implementation is based on BatchNorm2d, which in my tests has been unstable, in agreement with the authors. Inheriting from nn.InstanceNorm2d in that class works out of the box, produces stable training, and is a better representation of the paper.
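For intuition, here is a minimal sketch of the idea rather than darts' actual TimeBatchNorm2d code (the TimeNorm2d wrapper name and the exact reshaping are assumptions): each sample's (time, features) slice is treated as a single-channel 2D "image", and the choice of base norm decides whether statistics are shared across the batch (BatchNorm2d) or computed per sample (InstanceNorm2d).

import torch
from torch import nn

class TimeNorm2d(nn.Module):  # hypothetical wrapper, for illustration only
    def __init__(self, base_norm: nn.Module):
        super().__init__()
        self.norm = base_norm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, time, features) -> (batch, 1, time, features), normalize, squeeze back
        return self.norm(x.unsqueeze(1)).squeeze(1)

# BatchNorm2d pools its statistics over the whole batch, while InstanceNorm2d
# normalizes each sample's (time, features) plane independently
batch_variant = TimeNorm2d(nn.BatchNorm2d(num_features=1))
instance_variant = TimeNorm2d(nn.InstanceNorm2d(num_features=1))

out = instance_variant(torch.rand(32, 5, 100))  # shape preserved: (32, 5, 100)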

To Reproduce Train a model with TimeBatchNorm2d normalization and a lower batch size on non-stationary targets. The model doesn't train (in my case at least). Replace the base class with nn.InstanceNorm2d and the model trains just fine.

Expected behavior The model should train at least similarly to the LayerNorm option.

System (please complete the following information):

  • Python version: [e.g. 3.10]
  • darts version [e.g. 0.29.0]

Additional context I think it is best to either add another option using InstanceNorm2d or replace the current implementation altogether. Since BatchNorm isn't even an option to choose from currently, it is reasonable to choose the latter. Note that InstanceNorm and LayerNorm are slightly different, but they are very similar in practical terms compared to BatchNorm: https://discuss.pytorch.org/t/is-there-a-layer-normalization-for-conv2d/7595/6
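To make that comparison concrete, a small illustration of which dimensions each norm averages over for a (batch, time, features) tensor (shapes are arbitrary, not taken from darts):

import torch

x = torch.randn(32, 5, 100)                  # (batch, time, features)

# BatchNorm-style: statistics pooled over the batch dimension, so one sample's
# normalization depends on the other samples in the batch
bn_mean = x.mean(dim=0)                      # shape (5, 100)

# LayerNorm over (time, features): statistics computed independently per sample
ln_mean = x.mean(dim=(1, 2))                 # shape (32,)

# InstanceNorm2d with one artificial channel: also per sample, over the same plane
in_mean = x.unsqueeze(1).mean(dim=(2, 3))    # shape (32, 1)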

from the author: https://github.com/ditschuk/pytorch-tsmixer/issues/2

tRosenflanz avatar May 17 '24 00:05 tRosenflanz

Attaching some quick comparisons on my datasets. Green is LayerNorm, orange is BatchNorm2d, grey is InstanceNorm2d. [Screenshots: two loss curves, including val_loss, for the three normalizations]

This is the comparison using the TSMixer example notebook from your repo with full_training enabled: default vs. BatchNorm2d vs. InstanceNorm2d. Performance is very similar, indicating at least that there is no degradation in performance on other datasets. As the authors point out, BatchNorm is probably fine for stationary datasets but isn't necessarily stable in general. [Screenshot: comparison plot]

tRosenflanz avatar May 17 '24 00:05 tRosenflanz

Hi @tRosenflanz, and thanks for raising this issue and the investigation 🚀

Quick question:

  • are the results you show from a use case where you applied the zeroing of past target and future covariates as mentioned in #2381?
  • did you use use_reversible_instance_norm=True in your use case? (this applies Reversible InstanceNorm to the input and output; see the example after this list)
  • I also observed some differences between norms, but from my experiments I couldn't rule out batch norm since it indeed performed well on some datasets (as in the ETTh example)
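For reference, a hypothetical TSMixerModel configuration showing the settings mentioned above; the chunk lengths are placeholders, and this assumes the norm_type and use_reversible_instance_norm arguments as documented for darts' TSMixerModel.

from darts.models import TSMixerModel

model = TSMixerModel(
    input_chunk_length=24,                # placeholder values, not a recommendation
    output_chunk_length=12,
    norm_type="TimeBatchNorm2d",          # the normalization under discussion
    use_reversible_instance_norm=True,    # reversible InstanceNorm on the target input/output
)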

And a question about this :)

I think it is best to either add another option using InstanceNorm2d or replace the current implementation altogether.

  • what do you mean by replacing the current implementation altogether?

dennisbader avatar May 17 '24 07:05 dennisbader

  1. I am indeed using it with the zeroing, but without the reversible instance norm flag. I will repeat this experiment without the zeroing later today/tomorrow to make sure that's not the issue. I don't think it should make any difference since my past_covariates have several features and it is all concatenated together.

  2. After testing with ETTh1/2, I don't think replacing it is necessary, so you can ignore that comment. I initially thought that BatchNorm would fail harder, but it clearly holds up fine on some datasets. TimeInstanceNorm2d could just be another possible value for norm_type.

tRosenflanz avatar May 17 '24 07:05 tRosenflanz

Thanks for the additional info.

  1. use_reversible_instance_norm actually applies (reversible) InstanceNorm to the target values. I assume that it could also improve the results/stability of batch norm for your use case.

I do agree that adding InstanceNorm as a norm_type option would make sense. Is this something you'd like to contribute? :)

dennisbader avatar May 17 '24 08:05 dennisbader

Retested without the zeroing, with identical results. I even tried RINorm + TimeBatchNorm2d together and it still isn't working; my guess is the other covariates are mangled too much.

RINorm looks interesting, but unfortunately I can't use it with my datasets since the target itself is a "leak", hence the search for other normalizations.

I will make a PR to add it; it seems fairly trivial.
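For illustration, a minimal sketch of what such a TimeInstanceNorm2d could look like, mirroring the single-artificial-channel idea of TimeBatchNorm2d (a hypothetical version, not the actual PR):

import torch
from torch import nn

class TimeInstanceNorm2d(nn.InstanceNorm2d):  # hypothetical sketch
    def __init__(self):
        # one artificial channel covering the whole (time, features) plane
        super().__init__(num_features=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, time, features) -> (batch, 1, time, features), normalize, squeeze back
        return super().forward(x.unsqueeze(1)).squeeze(1)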

tRosenflanz avatar May 17 '24 17:05 tRosenflanz

I have come to the conclusion that this is completely unnecessary: the LayerNormNoBias in the current implementation is in fact identical to InstanceNorm2d. LayerNorm normalizes over the last two dimensions together, and InstanceNorm does so per channel, but since the channel is just a single artificial one, they end up identical. This would explain the experimental similarity. It's been a while since I dug into these norm layers, so maybe I am missing something that could make it useful; please let me know your thoughts.

import torch
from torch import nn

# TimeInstanceNorm2d: nn.InstanceNorm2d applied over the (time, features) plane
# with a single artificial channel, as sketched above
shape = (32, 5, 100)  # (batch, time, features)
inp = torch.rand(shape)
i2d = TimeInstanceNorm2d()
ln = nn.LayerNorm((shape[-2], shape[-1]), elementwise_affine=False)
i2d_out, ln_out = i2d(inp), ln(inp)
is_close = torch.isclose(i2d_out, ln_out, atol=1e-6)
is_close.all()

tensor(True)

tRosenflanz avatar May 18 '24 20:05 tRosenflanz

Thank you @tRosenflanz for the PR and for experimenting with the normalization functions. I agree with your observation that, in this context, LayerNormNoBias and InstanceNorm2d are equivalent.

Closing this issue; anyone should feel free to reopen it if they have a counterargument.

madtoinou avatar Aug 28 '24 07:08 madtoinou