[BUG] RIN does not seem to be scale invariant
Describe the bug
Reversible instance norm (RIN) doesn't produce the same results when the inputs are rescaled and the scaling is then inverted on the output, i.e. RIN does not appear to be correctly implemented and is actually sensitive to the input scale.
To Reproduce
from darts.models import NBEATSModel
import matplotlib.pyplot as plt

model_base = NBEATSModel.load_from_checkpoint(...)  # Needs to have an RIN layer.
ts = <TimeSeries of at least `input_length`>
rescales = [0.00001, 0.0001, 1.0, 100.0]

preds = []
for scale in rescales:
    input = ts * scale
    prd_out = model_base.predict(n=output_length, series=input)
    preds.append(prd_out.values() / scale)

# Error: the lines in this figure don't line up for small scales, implying it is not actually scale invariant.
plt.figure()
for i, p in enumerate(preds):
    plt.plot(p[-30:], linestyle="--", label=rescales[i], linewidth=0.2 * (i + 2))
plt.legend()

# Dropping the eps value seems to fix it.
model_base.model.rin.eps = 1e-15

preds = []
for scale in rescales:
    input = ts * scale
    prd_out = model_base.predict(n=output_length, series=input)
    preds.append(prd_out.values() / scale)

# The lines in this figure now line up for small scales, implying it is scale invariant.
plt.figure()
for i, p in enumerate(preds):
    plt.plot(p[-30:], linestyle=":", label=rescales[i], linewidth=0.2 * (i + 2))
plt.legend()
Expected behavior
For the code snippet above, the lines should line up, i.e. the predictions (after inverting the scaling) should be identical with and without rescaling.
System (please complete the following information):
- Python version: 3.11
- darts version: 0.31.0
Additional context NA
RINorm is only scale invariant when the variance of the sample is much greater than rin.eps. eps is added to the sample variance before scaling to avoid numerical instability in low-variance samples; if it were not, the input would be divided by zero and turned into NaNs whenever the lookback was constant.
I agree that 1e-5 is a rather large default value for epsilon, but it is what the original paper used.
class RINorm(nn.Module):
    def forward(self, x: torch.Tensor):
        # at the beginning of `PLForecastingModule.forward()`, `x` has shape
        # (batch_size, input_chunk_length, n_targets).
        # select all dimensions except batch and input_dim (0, -1)
        # TL;DR: calculate mean and variance over all dimensions except batch and input_dim
        calc_dims = tuple(range(1, x.ndim - 1))

        self.mean = torch.mean(x, dim=calc_dims, keepdim=True).detach()
        self.stdev = torch.sqrt(
            torch.var(x, dim=calc_dims, keepdim=True, unbiased=False) + self.eps
        ).detach()

        x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x
Thanks for your response. This makes some sense to me. I don't quite understand why it isn't exact at low scales, though, since it is just multiplying and dividing by a number. Is it because of numerical precision at low scales? Then I don't understand why dropping eps fixes the problem. I have tested this with and without the affine layer as well, and the pathology is observed for both.
On a side note: Looking through the original code (which was re-implemented here), I think the normalization with affine corrections can also be improved:
def _normalize(self, x):
    x = x - self.mean
    x = x / self.stdev
    if self.affine:
        x = x * self.affine_weight
        x = x + self.affine_bias
    return x

def _denormalize(self, x):
    if self.affine:
        x = x - self.affine_bias
        x = x / (self.affine_weight + self.eps * self.eps)
    x = x * self.stdev
    x = x + self.mean
    return x
(a) I don't think (self.affine_weight + self.eps*self.eps) actually protects against dividing by zero, because the unconstrained affine_weight could just be equal to -self.eps*self.eps. I think it should be a sign-preserving clip (i.e. w = sgn(w) * max(eps, |w|)).
(b) Whatever that operation is should also be applied in the _normalize function (or the DARTS equivalent) so that the transformation is (more!) exactly invertible; see the sketch after this list.
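For example, a minimal sketch of (a) + (b) together (hypothetical code, not from darts; `_clip_weight` is a helper name I made up, and `self.eps`, `self.affine_weight`, etc. are as in the snippet above):

import torch

def _clip_weight(w: torch.Tensor, eps: float) -> torch.Tensor:
    # Sign-preserving clip: force |w| >= eps so that dividing by the clipped
    # weight can never blow up, while leaving well-behaved weights untouched.
    # Weights that are exactly zero are mapped to +eps.
    sign = torch.where(w >= 0, torch.ones_like(w), -torch.ones_like(w))
    return sign * torch.clamp(w.abs(), min=eps)

def _normalize(self, x):
    x = x - self.mean
    x = x / self.stdev
    if self.affine:
        # use the clipped weight here too, so the forward pass matches what
        # _denormalize divides by
        x = x * _clip_weight(self.affine_weight, self.eps)
        x = x + self.affine_bias
    return x

def _denormalize(self, x):
    if self.affine:
        x = x - self.affine_bias
        x = x / _clip_weight(self.affine_weight, self.eps)
    x = x * self.stdev
    x = x + self.mean
    return x

Because the same clipped weight is used in both directions, _denormalize(_normalize(x)) recovers x (up to floating point) for any learned affine_weight, including values in the problematic near-zero band.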
The norm is exactly invertible at all scales, but it doesn't produce unit variance for inputs with very low scales, because it doesn't scale by the sample std; it scales by the square root of (the sample variance plus epsilon). If the sample variance is equal to epsilon, your NBEATSModel will receive a sample with a scale of 1/sqrt(2) rather than 1 and make accordingly different predictions. I believe this is a fundamental limitation of the method as reported.
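A quick standalone way to see this effect (toy values, independent of darts; only `eps` matches the RINorm default):

import torch

eps = 1e-5
x = torch.randn(1, 512, 1)  # toy lookback with roughly unit variance

for scale in (1e-5, 1e-4, 1.0, 100.0):
    xs = x * scale
    var = torch.var(xs, dim=1, keepdim=True, unbiased=False)
    x_norm = (xs - xs.mean(dim=1, keepdim=True)) / torch.sqrt(var + eps)
    # For scale = 1e-5, var ~ 1e-10 << eps, so the normalized std collapses to
    # roughly sqrt(var / eps) ~ 3e-3 instead of 1; for scale >= 1.0 it stays ~1.
    print(scale, x_norm.std().item())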
@andrewwarrington re: your first point, I think you are right.
The weights are initialized at 1, and under normal circumstances should probably not be small or negative. As currently implemented, the denormalize function does not invert the normalize function when affine_weight ~ eps**2, is undefined when affine_weight == -eps**2, and will flip the sign when -eps**2 < affine_weight < 0, which seems obviously undesirable to me.
I'm not an expert on affine transformations, but clipping at eps**2 in both normalize and denormalize to avoid divergence and preserve invertibility makes a lot of sense to me.
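A quick numeric check of those three regimes with the current `w + eps**2` inversion (toy scalar values; the mean/stdev/bias parts are dropped for clarity):

eps = 1e-5

def norm_affine(x, w):
    return x * w  # affine part of _normalize, bias omitted

def denorm_affine(y, w):
    return y / (w + eps**2)  # current inversion

for w in (1.0, eps**2, -eps**2 / 2, -eps**2):
    try:
        print(w, denorm_affine(norm_affine(1.0, w), w))
    except ZeroDivisionError:
        print(w, "undefined (division by zero)")
# w = 1.0        -> ~1.0  (inverts fine)
# w = eps**2     -> 0.5   (does not invert)
# w = -eps**2/2  -> -1.0  (sign flipped)
# w = -eps**2    -> undefined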