
We need to fix the implementation of `MLJModelInterface.predict` for classifiers

Open DilumAluthge opened this issue 4 years ago • 5 comments

In #91, I added an incorrect implementation of `MMI.predict` for classifiers. This allowed me to finish the pipeline, do cross-validation, add additional tests, etc.

But we should fix the implementation of `MMI.predict` before we register the package.

Basically, we need a method `MMI.predict` that outputs a `Vector{UnivariateFinite}`.
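As a rough illustration of the target shape (a standalone sketch with made-up labels and numbers, using a plain `Dict` as a stand-in for `UnivariateFinite`): `predict` should map an n×k matrix of class probabilities to a length-n vector containing one discrete distribution per observation.

```julia
# Hypothetical sketch: `Dict` stands in for `UnivariateFinite`,
# and the labels/probabilities are made up.
levels = ["a", "b", "c"]
probs  = [0.2 0.5 0.3;          # one row per observation,
          0.7 0.1 0.2]          # one column per class
preds  = [Dict(zip(levels, probs[i, :])) for i in 1:size(probs, 1)]
# preds[1]["b"] == 0.5
```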

DilumAluthge avatar Aug 28 '20 06:08 DilumAluthge

Just some notes here as I think through this...

`predict_joint` gives us a `SossMLJPredictor`. For the multinomial example, the fields have the following types:

```julia
julia> [p => typeof(getproperty(predictor_joint, p)) for p in propertynames(predictor_joint)]
4-element Array{Pair{Symbol,DataType},1}:
 :model => SossMLJModel{UnivariateFinite,Soss.Model{NamedTuple{(:X, :pool),T} where T<:Tuple,TypeEncoding(begin
    k = length(pool.levels)
    p = size(X, 2)
    β ~ Normal(0.0, 1.0) |> iid(p, k)
    η = X * β
    μ = NNlib.softmax(η; dims = 2)
    y_dists = UnivariateFinite(pool.levels, μ; pool = pool)
    n = size(X, 1)
    y ~ For((j->begin
                    y_dists[j]
                end), n)
end),TypeEncoding(Main)},NamedTuple{(:pool,),Tuple{CategoricalPool{String,UInt8,CategoricalValue{String,UInt8}}}},typeof(dynamicHMC),Symbol,typeof(SossMLJ.default_transform)}
  :post => Array{NamedTuple{(:β,),Tuple{Array{Float64,2}}},1}
  :pred => Soss.Model{NamedTuple{(:X, :pool, :β),T} where T<:Tuple,TypeEncoding(begin
    η = X * β
    μ = NNlib.softmax(η; dims = 2)
    y_dists = UnivariateFinite(pool.levels, μ; pool = pool)
    n = size(X, 1)
    y ~ For((j->begin
                    y_dists[j]
                end), n)
end),TypeEncoding(Main)}
  :args => NamedTuple{(:X, :pool),Tuple{Array{Float64,2},CategoricalPool{String,UInt8,CategoricalValue{String,UInt8}}}}
```

Abstractly, we need a mixture of instantiations of `pred`, one component for each value of `post`.

There's a bit more to it though, because we need to return just the last distribution, so the result will (in most cases) no longer be a Soss model. This part will require a new Soss method, which I can put together.

This will get us to a "mixture over the response distributions". Then, for the special case of `UnivariateFinite` (and also `Categorical` and `Multinomial`), we'll need a method that says a mixture of `UnivariateFinite`s is just another `UnivariateFinite`.

This won't be just any mixture: the components will have equal weight. I have an `EqualMix` in Soss that will at least be a good starting point for this.
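To make the `UnivariateFinite` special case concrete (a standalone sketch with made-up numbers, using bare probability vectors in place of `UnivariateFinite` objects): an equal-weight mixture of discrete distributions over a common support is again a discrete distribution, whose probability vector is just the average of the components' probability vectors.

```julia
# Each vector plays the role of one component's class probabilities
# (one component per posterior draw). Equal weights means a plain average.
comp_probs = [[0.2, 0.5, 0.3],
              [0.4, 0.4, 0.2]]
mixed = reduce((a, b) -> a .+ b, comp_probs) ./ length(comp_probs)
# mixed ≈ [0.3, 0.45, 0.25], still a valid probability vector
```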

cscherrer avatar Aug 30 '20 00:08 cscherrer

Think I'm getting close...

Say you start with `p = predictor_joint` from example-linear-regression.jl.

Then we can mess with the predictive distribution

```julia
julia> p.pred
@model (X, σ, β) begin
        η = X * β
        μ = η
        y ~ For(identity, Normal.(μ, σ))
    end
```

to get

```julia
julia> newpred = Soss.before(Soss.withdistributions(p.pred), p.model.response; strict=true, inclusive=false)
@model (X, σ, β) begin
        η = X * β
        μ = η
        _y_dist = For(identity, Normal.(μ, σ))
    end
```

Then with a little `marginals` function like

```julia
# Recover the vector of per-observation marginal distributions from a `For`.
function marginals(d::For)
    return d.f.(d.θ)
end
```
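To see what `marginals` does in isolation (a self-contained toy: `ToyFor` is a made-up stand-in for Soss's `For`, and a `NamedTuple` plays the role of a component distribution):

```julia
# Toy stand-in for Soss's `For`: `f` maps an index to a component
# "distribution" (here just a NamedTuple), `θ` is the index set.
struct ToyFor{F,T}
    f :: F
    θ :: T
end
marginals(d::ToyFor) = d.f.(d.θ)

d = ToyFor(j -> (μ = float(j), σ = 1.0), 1:3)
marginals(d)   # [(μ = 1.0, σ = 1.0), (μ = 2.0, σ = 1.0), (μ = 3.0, σ = 1.0)]
```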

we can get

```julia
julia> mar = marginals(rand(newpred(merge(p.args, particles(p.post))))._y_dist);

julia> typeof(mar)
Array{Normal{Particles{Float64,1000}},1}

julia> mar[1]
Normal{Particles{Float64,1000}}(
μ: -0.229 ± 0.0094
σ: 0.142 ± 0.0032
)
```

This is not quite what we want, but seems very close. And (as always with particles) I really like how clean and easy-to-read the representation is. Maybe we need something like

```julia
struct ParticleMixture{D,X} <: Distribution
    f :: D # the constructor, e.g. `Normal`
    pars :: X
end
```

So this would have the same data as `f(pars...)`, but would allow us to write proper `rand` and `logpdf` methods. Hmm... actually, this would be more natural as part of MonteCarloMeasurements. I'll think a bit more and then start an issue there for it.
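One way `rand` could fall out of this (a self-contained sketch, not the eventual design: particles are stored as plain vectors, and `f` here is a sampler closure rather than a real distribution constructor): draw one particle index uniformly (equal weights), instantiate that component, and sample from it.

```julia
# Hypothetical, simplified ParticleMixture: `pars` is a tuple of particle
# vectors, one per parameter; `f` turns one particle's parameters into a draw.
struct ParticleMixture{D,X}
    f    :: D
    pars :: X
end

function draw(pm::ParticleMixture)
    i = rand(1:length(first(pm.pars)))   # equal-weight choice of particle
    pm.f(getindex.(pm.pars, i)...)       # instantiate that component and sample
end

μs = [0.0, 1.0, 2.0]
σs = [1.0, 0.5, 0.25]
pm = ParticleMixture((μ, σ) -> μ + σ * randn(), (μs, σs))
draw(pm)   # one draw from the equal-weight particle mixture
```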

cscherrer avatar Sep 06 '20 20:09 cscherrer

Just remembered about https://github.com/baggepinnen/MonteCarloMeasurements.jl/issues/22

Lots of great background there, need to reread it myself :)

@DilumAluthge let's go ahead with the release and update as this moves ahead.

cscherrer avatar Sep 15 '20 21:09 cscherrer

I have created the Prediction project to keep track of progress on this issue.

DilumAluthge avatar Sep 15 '20 22:09 DilumAluthge

I'm going to mark this as potentially breaking, since it will probably require some changes to the return types of public functions.

DilumAluthge avatar Sep 15 '20 23:09 DilumAluthge