Soss.jl

Soss + Stheno

Open · willtebbutt opened this issue 6 years ago · 8 comments

I'm trying to make the basic features of Stheno play nicely with Soss, and I'm struggling to figure out what I'm doing wrong. In particular, here's some code that I'm trying:

using Soss, Stheno
m = @model x begin
    log_l ~ Normal(0, 1)
    f = Stheno.GP(eq(exp(-l)), Stheno.GPC())
    y ~ f(x, 0.01)
end
rand(m(x=randn(10)))

which produces the following error message:

ERROR: UndefVarError: Stheno not defined
Stacktrace:
 [1] getproperty at ./Base.jl:13 [inlined]
 [2] macro expansion at /home/wct23/.julia/packages/GeneralizedGenerated/x3uMp/src/closure_conv.jl:155 [inlined]
 [3] _rand(::Model{NamedTuple{(:x,),T} where T<:Tuple,begin
    log_l ~ Normal(0, 1)
    f = Stheno.GP(eq(exp(-l)), Stheno.GPC())
    y ~ f(x, 0.01)
end}, ::NamedTuple{(),Tuple{}}) at /home/wct23/.julia/packages/GeneralizedGenerated/x3uMp/src/closure_conv.jl:155
 [4] rand(::Soss.JointDistribution{NamedTuple{(),Tuple{}},NamedTuple{(:x,),T} where T<:Tuple,begin
    log_l ~ Normal(0, 1)
    f = Stheno.GP(eq(exp(-l)), Stheno.GPC())
    y ~ f(x, 0.01)
end}) at /home/wct23/.julia/packages/Soss/qX8ds/src/rand.jl:8
 [5] top-level scope at REPL[7]:1

Versions: Julia 1.3.0-rc3.0, Soss 0.6.0, Stheno master (this commit)

willtebbutt, Oct 24 '19

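Hi @willtebbutt,

Sorry, I had missed this. Stheno integration would be great! There's currently a GeneralizedGenerated.jl scoping limitation (the generated model code can't see the global name `Stheno`), but here's a quick workaround that passes the module in as a model argument: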

julia> m = @model x,S begin
           log_l ~ Normal(0, 1)
           f = S.GP(S.eq(exp(-log_l)), S.GPC())
           y ~ f(x, 0.01)
       end;

julia> rand(m(x=randn(10), S=Stheno))
(x = [-0.04074023995129731, 0.3870024821608242, -1.233541055727363, 0.6194909532814918, -1.631816569217919, -0.7451613144346232, -0.4107313103816707, 1.222974686587564, 0.4452615546836251, 0.7908933723879386],
 S = Stheno,
 f = GP{Stheno.ZeroMean{Float64},Stheno.Stretched{Float64,Stheno.EQ}}(Stheno.ZeroMean{Float64}(), Stheno.Stretched{Float64,Stheno.EQ}(2.2888525734792333, Stheno.EQ()), 1, GPC(1)),
 log_l = -0.8280506323777755,
 y = [0.06331057872393836, 0.6984615009039875, 0.9521603017770763, 1.2487380090603808, 0.7682254730656243, 1.115083042576203, 0.5635710597417615, 1.2848657321991466, 0.9375305632328326, 1.6434547286217749],)
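As I understand it, the workaround helps because Julia modules are ordinary values: inside the model body, `S.GP` is just `getproperty(S, :GP)` on an argument, so the body never needs the global name `Stheno` to be in scope wherever GeneralizedGenerated compiles it. A stdlib-only sketch of the same pattern, with `Statistics` standing in for Stheno and a made-up `center` function (not part of Soss or Stheno):

```julia
using Statistics  # stdlib module, standing in for Stheno in this sketch

# `S` is an ordinary argument; `S.mean` is getproperty(S, :mean), resolved at
# call time, so this body never refers to the global name `Statistics`.
center(x, S) = x .- S.mean(x)

center([1.0, 2.0, 3.0], Statistics)  # returns [-1.0, 0.0, 1.0]
```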

cscherrer, Oct 31 '19

I've added Stheno to Project.toml, but it's not working because Zygote won't precompile. Which version/branch of Zygote are you using?

cscherrer, Nov 04 '19

For Zygote, 0.3.4 (the latest release).

willtebbutt, Nov 04 '19

It looks like Zygote is very picky about IRTools versions but doesn't properly bound them. I was able to get it working by `]add`-ing Zygote and IRTools. BTW, there's now a Zygote 0.4.0 :)

Also, I just realized that although I `]add`ed Stheno, I'm still not `using` it, so really nothing has improved :(

I'd like to integrate it better, but I'm a bit worried about making sure to constrain versions of everything just the right amount.

Here's what ended up working:

Project Soss v0.7.0
    Status `~/git/jl/Soss/Project.toml`
  [0bf59076] AdvancedHMC v0.2.7
  [76274a88] Bijectors v0.4.0
  [163ba53b] DiffResults v0.0.4
  [31c24e10] Distributions v0.19.2
  [bbc10e6e] DynamicHMC v2.1.0
  [f6369f11] ForwardDiff v0.10.5
  [6b9d7cbe] GeneralizedGenerated v0.1.3
  [86223c79] Graphs v0.10.3
  [7869d1d1] IRTools v0.3.0
  [c8e1da08] IterTools v1.2.0
  [6fdf6af0] LogDensityProblems v0.9.1
  [d8e11817] MLStyle v0.3.1
  [1914dd2f] MacroTools v0.5.1
  [0987c9cc] MonteCarloMeasurements v0.5.3
  [d9ec5142] NamedTupleTools v0.11.0
  [438e738f] PyCall v1.91.2
  [189a3867] Reexport v0.2.0
  [c5292f4c] ResumableFunctions v0.5.1
  [37e2e3b7] ReverseDiff v0.3.1
  [55797a34] SimpleGraphs v0.3.0
  [ec83eff0] SimplePartitions v0.2.0
  [b2aef97b] SimplePosets v0.0.3
  [4c63d2b9] StatsFuns v0.8.0
  [8188c328] Stheno v0.3.2
  [24249f21] SymPy v1.0.7
  [84d833dd] TransformVariables v0.3.8
  [e88e6eb3] Zygote v0.4.0
  [37e2e46d] LinearAlgebra 
  [de0858da] Printf 
  [9a3f8284] Random 
  [10745b16] Statistics 

Not sure what's constraining Distributions (I think the latest is v0.21 or so), but that's... not great.

cscherrer, Nov 04 '19

Hmm, okay. Distributions doesn't currently work with v0.4.0 of Zygote due to a SpecialFunctions dependency, so I'm currently running with Zygote v0.3.4 and ZygoteRules v0.1.0. I'm also not sure what's constraining Distributions. I lower-bound it to v0.21.0 in `Stheno`, but I could change that if it would help.
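A note on why these pins collide, assuming Pkg's default caret semantics for `[compat]` entries: for a 0.x package, an entry like `"0.21"` admits [0.21.0, 0.22.0), so a resolver that has settled on Distributions 0.19.2 cannot also satisfy a 0.21 bound, and an entry like `Zygote = "0.3"` would keep 0.4.0 out. The helpers below (`caret_upper` and `compatible` are hypothetical names, not Pkg API) only encode that rule, and only for specs whose major or minor part is nonzero:

```julia
# Hypothetical helper (not part of Pkg): exclusive upper bound implied by a
# caret-style compat entry such as "0.3" or "1.2.1".
function caret_upper(spec::AbstractString)
    v = VersionNumber(spec)              # "0.3" parses as v"0.3.0"
    if v.major > 0
        return VersionNumber(v.major + 1, 0, 0)   # ^1.2.1 -> [1.2.1, 2.0.0)
    elseif v.minor > 0
        return VersionNumber(0, v.minor + 1, 0)   # ^0.3   -> [0.3.0, 0.4.0)
    else
        error("0.0.x specs follow special rules; not handled in this sketch")
    end
end

# Is version `v` allowed by the caret entry `spec`?
compatible(v::VersionNumber, spec::AbstractString) =
    VersionNumber(spec) <= v < caret_upper(spec)

compatible(v"0.3.4", "0.3")    # true: Zygote 0.3.4 satisfies "0.3"
compatible(v"0.4.0", "0.3")    # false: 0.4.0 is outside [0.3.0, 0.4.0)
compatible(v"0.19.2", "0.21")  # false: below the 0.21 bound
```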

willtebbutt, Nov 04 '19

In case it makes things easier, I've added you as a collaborator, and created https://github.com/cscherrer/Soss.jl/tree/stheno

cscherrer, Nov 04 '19

This works in v0.8:

julia> m = @model x begin
           log_l ~ Normal(0, 1)
           f = Stheno.GP(eq(exp(-log_l)), Stheno.GPC())
           y ~ f(x, 0.01)
       end
@model x begin
        log_l ~ Normal(0, 1)
        f = Stheno.GP(eq(exp(-log_l)), Stheno.GPC())
        y ~ f(x, 0.01)
    end


julia> rand(m(x=randn(10)))
(x = [-0.18116708254415098, -0.19534937893921914, -2.5787646648351634, -1.2215879425363827, -2.1767125265850202, -0.6844208442501819, 1.7116200174728649, -1.211904330744371, 0.9237978711932542, 2.5672978727145894],
 f = GP{Stheno.ZeroMean{Float64},Stheno.Stretched{Float64,Stheno.EQ}}(Stheno.ZeroMean{Float64}(), Stheno.Stretched{Float64,Stheno.EQ}(1.4657457835104046, Stheno.EQ()), 1, Stheno.GPC(1)), log_l = -0.3823641801808569, y = [-1.1745283754760036, -1.159683442935461, 0.20763478852991196, -1.8210157522455404, 0.43808611634562833, -1.9699497363285094, 1.3864736406342097, -1.954932802749259, 0.35645511546448283, 1.2815533250693458])
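To spell out what this model samples: `exp(-log_l)` is the stretch applied to the inputs, i.e. an inverse length scale, and the output agrees, since `exp(0.3823641801808569)` ≈ 1.4657457835, the stored `Stretched` factor. Below is a dependency-free sketch of the generative process behind `y ~ f(x, 0.01)`. It assumes Stheno's EQ kernel is `exp(-(x - x')^2 / 2)`, that stretching evaluates it at `a*x, a*x'`, and that `0.01` is the observation-noise variance; `eq_stretched` and `sample_y` are hypothetical helper names, not Stheno API.

```julia
using LinearAlgebra, Random

# Stretched EQ kernel matrix, assuming k(x, x') = exp(-(a*x - a*x')^2 / 2).
eq_stretched(x, a) = [exp(-abs2(a * (xi - xj)) / 2) for xi in x, xj in x]

# Draw y ~ N(0, K + σ² I): a zero-mean GP evaluated at x, plus i.i.d. noise
# with variance σ² (assumed meaning of the 0.01 in f(x, 0.01)).
function sample_y(rng, x, log_l; σ² = 0.01)
    a = exp(-log_l)                      # stretch factor = 1 / length scale
    K = eq_stretched(x, a) + σ² * I      # noisy covariance matrix
    C = cholesky(Symmetric(K))
    return C.L * randn(rng, length(x))   # correlated draw via Cholesky factor
end

rng = MersenneTwister(1)
x = randn(rng, 10)
y = sample_y(rng, x, -0.3823641801808569)
```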

cscherrer, Nov 29 '19

Hi @willtebbutt, I think I closed this too early. So, this works:

using Soss, Stheno
m = @model x begin
    log_l ~ Normal(0, 1)
    f = Stheno.GP(stretch(EQ(),exp(-log_l)), Stheno.GPC())
    y ~ f(x, 0.01)
end

x = [1:10;]

then

julia> r = rand(m(x=x))
(f = GP{Stheno.ZeroMean{Float64},Stheno.Stretched{Array{Float64,1},EQ,typeof(identity)}}(Stheno.ZeroMean{Float64}(), Stheno.Stretched{Array{Float64,1},EQ,typeof(identity)}([0.1503990725174873], EQ(), identity), 1, GPC(1)), log_l = 1.8944630342574382, y = [-0.6385019767952739, -0.6568345287713035, -0.7614124412019684, -0.6573105642557342, -0.806632268687059, -0.8402795302144008, -0.7312437035670117, -0.6445296881421446, -0.616990857237575, -0.5990118606193565])

But there's no `xform` method that handles it, so sampling didn't work. It does work if I change

function xform(d, _data=NamedTuple())
    if hasmethod(support, (typeof(d),))
        return asTransform(support(d)) 
    end

    error("Not implemented:\nxform($d)")
end

to

function xform(d, _data=NamedTuple())
    if hasmethod(support, (typeof(d),))
        return asTransform(support(d)) 
    elseif hasmethod(length, (typeof(d),)) 
        return as(Vector, length(d))
    end

    error("Not implemented:\nxform($d)")
end

Then, e.g.,

julia> xform(m(x=x), (y=r.y,))
TransformVariables.TransformTuple{NamedTuple{(:log_l,),Tuple{TransformVariables.Identity}}}((log_l = asℝ,), 1)

julia> dynamicHMC(m(x=x), (y=r.y,)) |> particles
(log_l = 2.08 ± 0.48,)

The second `hasmethod` check seems a little hacky, and I haven't finalized it yet. An alternative would be to define a `support` method for `FiniteGP`. What do you think?
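A sketch of the `support`-method alternative, using stand-ins so it runs without Soss or Stheno: `StubFiniteGP`, the local `support` function, and the tuple return values are hypothetical placeholders for `FiniteGP`, the `support` function that `xform` calls, and the TransformVariables transforms. With a `support` method for the GP type, the original single-branch `xform` handles it unchanged:

```julia
# Hypothetical stand-in for Stheno's FiniteGP: only its length matters here.
struct StubFiniteGP
    n::Int  # number of input points
end
Base.length(d::StubFiniteGP) = d.n

function support end  # generic function with no methods yet

# Same shape as the original xform: only consults `support`.
function xform(d, _data = NamedTuple())
    if hasmethod(support, (typeof(d),))
        return (:from_support, support(d))
    end
    error("Not implemented:\nxform($d)")
end

# The alternative: declare the GP's support (all of ℝⁿ, as a placeholder tuple),
# so no extra length-based branch is needed in xform.
support(d::StubFiniteGP) = (:reals, length(d))

xform(StubFiniteGP(10))  # returns (:from_support, (:reals, 10))
```

One difference from the `hasmethod(length, ...)` route: a type only gets the vector transform when its support is declared deliberately, whereas the `length` check would also match any unrelated type that happens to define `length`.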

cscherrer, Sep 14 '20