ReactiveMP.jl
ReactiveMP.jl copied to clipboard
Fit at the beginning of an array takes 20x more iterations than the rest
As discussed, there is an odd artefact: the fit becomes very good almost immediately everywhere except at the beginning of the array, which then takes up to 20x more iterations to converge (the number of iterations required appears to be linked to the array size).
Result with 20 iterations (observe the beginning):

Result with 365 iterations (= N, the number of data points):

MWE for a simple curve-fitting scenario
# Generate data to fit: a sinusoid plus a linear trend, corrupted by Gaussian noise
time_max = 365
noise_scale = 0.05
time_index = collect(1:time_max) ./ time_max  # normalised time axis, 1/time_max .. 1
p = 180 / time_max  # period of the sinusoid, expressed on the normalised time axis
growth = 2
offset = 0
y = offset .+
    sin.(time_index .* 2π ./ p) .+
    growth .* time_index .+
    rand(Normal(0, noise_scale), time_max)
plot(y)

# Generate a B-spline basis to approximate the function.
# Note: the boundary knots must lie outside of the needed range; otherwise the
# basis contains a row of all zeros (which breaks the backprop).
X = Splines2.bs(time_index; df=10, boundary_knots=(-0.01, 1.01));
# Build the model: Bayesian linear regression of `y` on the rows of `X`.
# `n` is the number of observations (rows of `X`) and `dim_x` the number of
# regressors (length of `beta`).
@model function linreg(X, n, dim_x)
    # Observed targets and the latent linear predictor for each observation
    y = datavar(Float64, n)
    aux = randomvar(n)

    # Priors: observation precision, intercept and regression weights
    sigma ~ GammaShapeRate(1.0, 1.0)
    intercept ~ NormalMeanVariance(0.0, 2.0)
    beta ~ MvNormalMeanPrecision(zeros(dim_x), diageye(dim_x))

    # Likelihood: each observation is Gaussian around its linear predictor,
    # with `sigma` acting as the precision
    for i in 1:n
        aux[i] ~ intercept + dot(X[i, :], beta)
        y[i] ~ NormalMeanPrecision(aux[i], sigma)
    end

    return beta, aux, y
end
# Factorisation constraint: approximate the joint posterior over `aux` and
# `sigma` by the product of their individual marginals.
constraints = @constraints begin
    q(aux, sigma) = q(aux)q(sigma)
end
# Run inference (timed). The model receives `X` together with its dimensions
# (`size(X)...` expands to `n, dim_x`); only the last-iteration posteriors of
# `sigma`, `beta` and `aux` are kept.
@time results = inference(;
    model=Model(linreg, X, size(X)...),
    data=(y=y,),
    constraints=constraints,
    initmessages=(intercept=vague(NormalMeanVariance),),
    initmarginals=(sigma=GammaShapeRate(1.0, 1.0),),
    returnvars=(sigma=KeepLast(), beta=KeepLast(), aux=KeepLast()),
    iterations=20,
    warn=true,
    free_energy=true,
)
# Plot results: posterior mean of the linear predictor with a ribbon of one
# standard deviation, derived from the posterior mean of the precision `sigma`.
# Note: observe the divergence in the first 50 data points.
# It disappears as you increase the number of iterations.
plot(
    mean.(results.posteriors[:aux]);
    ribbon=sqrt(inv(mean(results.posteriors[:sigma]))),
    label="Fitted",
)
plot!(y; label="Observed data")