Offset/multiplier noncentering gets trapped at low sigma
Summary:
When a normal distribution is non-centered via offset/multiplier and the standard deviation is initialized to a very low value, the sampler gets stuck at an extremely low standard deviation and never recovers. Manual non-centering of the same model does not encounter this problem. This makes non-centering via offset/multiplier unusable for some models (see below).
For more, see: https://discourse.mc-stan.org/t/offset-multiplier-initialization/20712
Description and Reprex
Here's @bbbales2 on the issue, pasted over from discourse (link above):
Here is the offset-multiplier model
parameters{
real<lower = 0> sigma;
real<multiplier=sigma> x;
}
model{
sigma ~ std_normal();
x ~ normal(0, sigma);
}
Here is the manual offset-multiplier model:
parameters{
real<lower = 0> sigma;
real x_raw;
}
transformed parameters {
real x = x_raw * sigma;
}
model{
sigma ~ std_normal();
x ~ normal(0, sigma);
target += log(sigma);
}
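For reference, the target of the manual model works out to a standard normal on x_raw. Substituting x = x_raw * sigma into the normal(0, sigma) log density and adding the log(sigma) Jacobian term gives:

```latex
\begin{aligned}
\log \mathcal{N}(x \mid 0, \sigma) + \log \sigma
  &= -\log \sigma - \frac{x^2}{2\sigma^2} - \tfrac{1}{2}\log(2\pi) + \log \sigma \\
  &= -\frac{(x_{\text{raw}}\,\sigma)^2}{2\sigma^2} - \tfrac{1}{2}\log(2\pi) \\
  &= \log \mathcal{N}(x_{\text{raw}} \mid 0, 1)
\end{aligned}
```

The offset/multiplier transform is documented as applying the same change of variables, x = sigma * x_raw, with the same log(sigma) Jacobian term. On that reading, the two programs should define the same posterior, so a difference in sampling behavior would not come from the target density itself.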
Code to run them is
library(tidyverse)
library(cmdstanr)
mod1 = cmdstan_model("mod1.stan")
inits_chain_1 = list(sigma = 1e-20)
fit1 = mod1$sample(chains = 1, init = list(inits_chain_1), iter_sampling = 1000)
fit1$summary()
mod2 = cmdstan_model("mod2.stan")
fit2 = mod2$sample(chains = 1, init = list(inits_chain_1), iter_sampling = 1000)
fit2$summary()
You get output like this for the built-in offset-multiplier:
> fit1$summary()
# A tibble: 3 x 10
variable mean median sd mad q5 q95 rhat ess_bulk
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ -9.01e+38 -9.01e+38 0 0 -9.01e+38 -9.01e+38 NA NA
2 sigma 3.12e-20 3.12e-20 0 0 3.12e-20 3.12e-20 NA NA
3 x -1.33e+ 0 -1.33e+ 0 0 0 -1.33e+ 0 -1.33e+ 0 NA NA
# … with 1 more variable: ess_tail <dbl>
And the output with a custom offset-multiplier looks like:
> fit2$summary()
# A tibble: 4 x 10
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ -1.57 -1.24 0.987 0.761 -3.64 -0.563 1.00 292. 430.
2 sigma 0.821 0.703 0.594 0.591 0.0667 2.02 1.00 437. 366.
3 x_raw 0.0115 0.0396 0.983 1.01 -1.64 1.59 1.00 523. 660.
4 x 0.00698 0.0104 0.967 0.547 -1.66 1.51 1.00 492. 503.
This is pretty repeatable: the custom code doesn't have a problem with these inits, but the built-in does.
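As a rough consistency check on the built-in output (fit1), the reported lp__ can be compared with the quadratic term of the normal log density, using the rounded values of x and sigma from the table. This is only arithmetic on the displayed numbers. It ignores the log(sigma) and prior terms, which are negligible at this scale.

```latex
-\frac{x^2}{2\sigma^2}
  \approx -\frac{(-1.33)^2}{2\,(3.12\times 10^{-20})^2}
  = -\frac{1.7689}{1.947\times 10^{-39}}
  \approx -9.1\times 10^{38},
\qquad
\frac{x}{\sigma} \approx \frac{-1.33}{3.12\times 10^{-20}} \approx -4.3\times 10^{19}
```

That agrees with the reported lp__ of -9.01e+38 to within the rounding of the displayed values. It suggests the chain is sitting at a point where x is of order one while sigma is of order 1e-20, roughly 4e19 standard deviations from zero. It does not by itself say how the chain got there.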
Encountering this issue "in the wild"
Some classes of model reliably pinch through very small standard deviations early in warmup. Here's an example--not extreme enough to hit the "sticky boundary", but enough to show why it can be an issue. Notice how in early warmup sigma pinches down to a very low value before recovering. This is consistent across seeds.
data{
int n;
real y[n];
}
parameters{
real<lower = 0> sigma;
real<multiplier=sigma> x[n];
}
model{
sigma ~ std_normal();
x ~ normal(0, sigma);
y ~ normal(x, .01);
}
Code to run:
library(cmdstanr)
pinch <- cmdstan_model("pinch.stan")
set.seed(10)
n <- 50000
pinch_samples <- pinch$sample(data = list(n = n, y = rnorm(n)),
                              chains = 1, save_warmup = TRUE,
                              iter_warmup = 30, iter_sampling = 1)
# $draws() is still really slow on many-parameter models, so read the output CSV directly
pinch_csv <- read_cmdstan_csv(pinch_samples$output_files())
plot(pinch_csv$warmup_draws[, 1, "sigma"])
<img width="853" alt="Screen Shot 2021-05-07 at 12 32 38 PM" src="https://user-images.githubusercontent.com/11272480/117487161-54be9800-af30-11eb-8147-7d482dac60ad.png">
It's not the end of the world, because manual non-centering still works fine, but this issue makes offset/multiplier noncentering unusable in some of the models I work with.
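For comparison, here is a minimal sketch of what a manually non-centered version of the pinch model could look like. It is an illustration, not the workaround actually used in the models mentioned above. The name x_raw is carried over from the reprex, and the vector types are an assumption made for brevity. It uses x_raw ~ std_normal() instead of the explicit Jacobian term from the reprex; the two forms define the same density.

```stan
data{
  int n;
  real y[n];
}
parameters{
  real<lower = 0> sigma;
  vector[n] x_raw;               // standardized effects (hypothetical name)
}
transformed parameters {
  vector[n] x = sigma * x_raw;   // manual non-centering: scale by sigma
}
model{
  sigma ~ std_normal();
  x_raw ~ std_normal();          // implies x ~ normal(0, sigma)
  y ~ normal(x, .01);
}
```

It can be run with the same R code as above by pointing cmdstan_model() at a file containing this program.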
Current Version:
v2.26.1
Let me know if this actually needs to be filed against math or somewhere else.
This is interesting and seems quite annoying, though I don't know what to do about it. @LuZhangstat @bob-carpenter @avehtari for visibility
Moving this to Stan, as I don't think the interface can really help with this problem.