ReactiveMP.jl

Fit at the beginning of an array takes 20x more iterations than the rest

Status: Open · opened by svilupp 3 years ago · 0 comments

As discussed, there is an odd artefact: the fit becomes good almost immediately across most of the series, except at the beginning of the array. That part takes up to 20x more iterations to converge, and the number of iterations needed appears to scale with the array size.

Result after 20 iterations (note the poor fit at the beginning): [image]

Result after 365 iterations (equal to N, the number of data points): [image]

Minimal working example (MWE) for a simple curve-fitting scenario:

# Generate synthetic data to fit: a sinusoid plus a linear trend plus Gaussian noise.
# The noise is drawn from an explicitly seeded RNG so that this reproducer gives the
# same data on every run (change `noise_seed` to draw a different realisation).
using Random: MersenneTwister

time_max = 365        # number of observations
noise_scale = 0.05    # standard deviation of the additive Gaussian noise
noise_seed = 42       # seed for the noise RNG (reproducibility of the MWE)

# Time rescaled to (0, 1]; kept as a Vector{Float64} like the original `collect`.
time_index = collect(1:time_max) ./ time_max
p = 180 / time_max    # sinusoid period on the rescaled axis (180 time steps)
growth = 2            # slope of the linear trend
offset = 0            # constant offset

# Fully fused broadcast: a single pass over `time_index`; the (undotted) `rand` call
# is evaluated once and its vector is added element-wise.
y = offset .+ sin.(time_index .* 2π ./ p) .+ growth .* time_index .+
    rand(MersenneTwister(noise_seed), Normal(0, noise_scale), time_max)
plot(y)

# Generate a spline basis to approximate the function.
# The boundary knots are placed slightly outside [0, 1] so that no row of the design
# matrix `X` is all zeros; per the original report, such a row breaks inference.
X = Splines2.bs(time_index, df=10, boundary_knots=(-0.01, 1.01));

# Build the model: Bayesian linear regression of `y` on the spline basis `X`.
#
# Arguments (as used in the body below):
#   X     - design matrix; row `X[i,:]` holds the regressors for observation i
#   n     - number of observations (length of `y` and `aux`)
#   dim_x - number of regressors (length of `beta`)
# Returns the tuple (beta, aux, y).
@model function linreg(X,n,dim_x)

    # Element type of the observed data variables.
    T=Float64
    
    # `y` holds the n observations; `aux[i]` is the latent mean of `y[i]`.
    y = datavar(T, n)
    aux = randomvar(n)

    # NOTE: despite its name, `sigma` is used as the *precision* of the observation
    # noise (second argument of NormalMeanPrecision below), not a standard deviation.
    sigma  ~ GammaShapeRate(1.0, 1.0)
    intercept ~ NormalMeanVariance(0.0, 2.0)
    # Zero-mean prior on the spline coefficients; diageye(dim_x) is presumably a
    # dim_x x dim_x identity precision matrix - TODO confirm.
    beta  ~ MvNormalMeanPrecision(zeros(dim_x), diageye(dim_x))
    
    for i in 1:n
        # Latent mean = intercept + linear combination of the spline regressors.
        # NOTE(review): X[i,:] copies the row on every iteration.
        aux[i] ~ intercept + dot(X[i,:], beta)
        # Gaussian likelihood centred on aux[i] with precision `sigma`.
        y[i] ~ NormalMeanPrecision(aux[i], sigma) 
    end

    return beta,aux,y
end

# Posterior factorization constraint: the approximate posterior q(aux, sigma) is
# restricted to factorize as q(aux)q(sigma), i.e. `aux` and `sigma` are treated as
# independent in the approximation.
constraints = @constraints begin 
    q(aux, sigma) = q(aux)q(sigma)
end

# Run (timed) inference on the `linreg` model for a fixed number of iterations,
# keeping only the final posterior of each requested variable and tracking free energy.
@time results = inference(
    # Model arguments: X, then (n, dim_x) splatted from size(X).
    model         = Model(linreg, X, size(X)...),
    data          = (y = y,),
    iterations    = 20,
    constraints   = constraints,
    returnvars    = (sigma = KeepLast(), beta = KeepLast(), aux = KeepLast()),
    # `intercept` starts from a vague Normal message; `sigma` starts from a marginal
    # equal to its prior in the model.
    initmarginals = (sigma = GammaShapeRate(1.0, 1.0),),
    initmessages  = (intercept = vague(NormalMeanVariance),),
    free_energy   = true,
    warn          = true,
)

# Plot results: the posterior mean of every latent `aux[i]` against the observations.
# The ribbon is a plug-in noise standard deviation. In the model `sigma` is the noise
# *precision*, so the std is sqrt(1 / E[sigma]).
# Note: observe the divergence in the first 50 data points; it disappears as the
# number of iterations increases.
plot(
    map(mean, results.posteriors[:aux]);
    ribbon = sqrt(inv(mean(results.posteriors[:sigma]))),
    label = "Fitted",
)
plot!(y; label = "Observed data")

— svilupp, Jun 18 '22 06:06