Fit at the beginning of an array takes 20x more iterations than the rest #146

Open
svilupp opened this issue Jun 18, 2022 · 0 comments
As discussed, there is an odd artefact: the fit becomes good almost immediately everywhere except at the beginning of the array, which takes up to 20x more iterations to converge. The number of iterations needed appears to be linked to the array size.

Result with 20 iterations (observe the beginning):
image

Result with 365 iterations (= N, the number of data points):
image

MWE (minimal working example) for a simple curve-fitting scenario:

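# Packages used by this MWE (assumed; the original post omitted the imports)
using ReactiveMP, GraphPPL, Rocket, Distributions, LinearAlgebra, Plots, Splines2

# Generate data to fit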
time_max=365
noise_scale=0.05

time_index=collect(1:time_max)/time_max
p=180/time_max
growth=2
offset=0
y=offset .+ sin.(time_index*2π/p) .+ growth.*time_index .+rand(Normal(0,noise_scale),time_max)
plot(y)

# Generate splines to approximate the function
# Note: the boundary knots must lie outside the data range to avoid a row of all zeros (which breaks the backprop)
X=Splines2.bs(time_index,df=10,boundary_knots=(-0.01,1.01));

# Build the model
@model function linreg(X,n,dim_x)

    T=Float64
    
    y = datavar(T, n)
    aux = randomvar(n)

    sigma  ~ GammaShapeRate(1.0, 1.0)
    intercept ~ NormalMeanVariance(0.0, 2.0)
    beta  ~ MvNormalMeanPrecision(zeros(dim_x), diageye(dim_x))
    
    for i in 1:n
        aux[i] ~ intercept + dot(X[i,:], beta)
        y[i] ~ NormalMeanPrecision(aux[i], sigma) 
    end

    return beta,aux,y
end

constraints = @constraints begin 
    q(aux, sigma) = q(aux)q(sigma)
end

# Run inference
@time results = inference(
    model = Model(linreg,X,size(X)...),
    data  = (y = y,),
    constraints = constraints,
    initmessages = (intercept = vague(NormalMeanVariance),),
    initmarginals = (sigma = GammaShapeRate(1.0, 1.0),),
    returnvars   = (sigma = KeepLast(),beta = KeepLast(), aux = KeepLast()),#,y=KeepLast()),
    iterations   = 20,
    warn = true,
    free_energy=true
)

# Plot results
# Note: observe the divergence in the first 50 data points
# It disappears as you increase the number of iterations
plot(mean.(results.posteriors[:aux]), ribbon = (results.posteriors[:sigma]|>mean|>inv|>sqrt),label="Fitted")
plot!(y,label="Observed data")
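
As a reference point for what the converged fit should look like at the start of the array, a closed-form least-squares fit on the same spline basis can be compared against the message-passing result. The sketch below is not part of the original report: the fixed seed and the names `X_design`, `coef`, `y_ols`, `resid_start` and `resid_rest` are hypothetical, and it only assumes that `Splines2.bs` behaves as in the MWE above.

```julia
using Random, Distributions, Splines2

Random.seed!(42)  # assumption: fixed seed for reproducibility (not in the original MWE)

# Same synthetic data as in the MWE
time_max = 365
noise_scale = 0.05
time_index = collect(1:time_max) / time_max
p = 180 / time_max
growth = 2
y = sin.(time_index * 2π / p) .+ growth .* time_index .+
    rand(Normal(0, noise_scale), time_max)

# Same spline basis as in the MWE
X = Splines2.bs(time_index, df=10, boundary_knots=(-0.01, 1.01))

# Least-squares baseline: intercept column plus spline basis
X_design = hcat(ones(time_max), X)
coef = X_design \ y          # QR-based least-squares solve
y_ols = X_design * coef      # fitted values

# Largest absolute residual in the first 50 points versus the rest
resid_start = maximum(abs.(y_ols[1:50] .- y[1:50]))
resid_rest = maximum(abs.(y_ols[51:end] .- y[51:end]))
println("max |residual|, first 50 points: ", resid_start)
println("max |residual|, remaining points: ", resid_rest)
```

If the two residuals are of similar size here, the basis itself can represent the start of the curve, which would point at the inference schedule rather than the model as the source of the slow convergence.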
