Bijectors.jl

Nested transformation with Shift does not work for Matrix output

Open mgmverburg opened this issue 3 years ago • 1 comments

What I want to achieve is, for example, a logit-transformed outcome where covariates can have different effects and therefore shift the mean of each entry of the matrix. I already had a way to do this, but it involved an arraydist with a for-loop, and when optimizing I noticed that this approach caused a slowdown (among other things, I believe it introduced type instabilities, judging by code_warntype). So I wanted to find a way to do this in one shot. Most things work, but this specific use case does not, even though in theory it should.

To make it slightly weirder, the Shift bijector does work with 2D data, as in model_1 in the code below, which throws no error. However, wrapping it in another layer such as a logit transform, as in model_2, throws the error listed below the code.
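For reference, here is the density the nested transformation should define, written out as a sketch. It assumes the base distribution is an M×N matrix of independent standard normals, that z is the M×N shift, and that the outer layer is the elementwise logistic function (the inverse of logit on (0, 1)); the symbol \phi for the standard normal density is notation introduced here, not something from the package.

```latex
\begin{aligned}
X_{ij} &\sim \mathcal{N}(0, 1), \qquad
Y_{ij} = \operatorname{logistic}(X_{ij} + z_{ij}), \\
\log p(Y \mid z) &= \sum_{i=1}^{M} \sum_{j=1}^{N}
  \Big[ \log \phi\big(\operatorname{logit}(Y_{ij}) - z_{ij}\big)
        - \log Y_{ij} - \log(1 - Y_{ij}) \Big].
\end{aligned}
```

The shift only moves the argument of \phi and adds nothing to the Jacobian term; the two remaining log terms come from the derivative of logit, 1 / (y(1 - y)).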

using Bijectors, Turing, LinearAlgebra, Random

M, N = 8, 20
output = rand(LogitNormal(0, 1), M, N)

@model function test_1(output, M, N)
    mvn = MvNormal(zeros(M), LinearAlgebra.I)
    z ~ filldist(Normal(0, 1), M, N)
    output ~ transformed(filldist(mvn, N), Bijectors.Shift(z))
end

model_1 = test_1(output, M, N)
chain_1 = sample(model_1, NUTS(0.65), 10)


@model function test_2(output, M, N)
    mvn = MvNormal(zeros(M), LinearAlgebra.I)
    b = inv(Bijectors.Logit{2}(0.0, 1.0))
    z ~ filldist(Normal(0, 1), M, N)
    output ~ transformed(transformed(filldist(mvn, N), Bijectors.Shift(z)), b)
end

model_2 = test_2(output, M, N)
chain_2 = sample(model_2, NUTS(0.65), 10)
Error message:

ERROR: MethodError: no method matching _logabsdetjac_shift(::Array{Float64,2}, ::Array{Float64,2}, ::Val{2})
Closest candidates are:
  _logabsdetjac_shift(::T1, ::AbstractArray{T2,2}, ::Val{2}) where {T1<:Union{...}, ...}
  ...
Stacktrace (abridged, innermost call first): logabsdetjac, logpdf_with_trans, _logpdf, logpdf, loglikelihood, observe, tilde_observe, _evaluate, evaluate_threadsafe, step, mcmcsample, sample, top-level scope
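The error looks like a plain dispatch gap: methods exist for some combinations of shift type, input type and Val{N}, but not for matrix shift plus matrix input plus Val{2}. A minimal sketch of that situation follows; `toy_ladj_shift` and its methods are hypothetical stand-ins written for illustration and are not the Bijectors.jl source.

```julia
# Hypothetical stand-in for a shift log-abs-det-Jacobian helper.
# The name `toy_ladj_shift` is illustrative only, not the package's function.
# Val{N} encodes how many dimensions make up one sample.

# One sample is a vector: return a single scalar.
toy_ladj_shift(a::Union{Real,AbstractVector}, x::AbstractVector{T}, ::Val{1}) where {T<:Real} = zero(T)

# Each column is a sample: return one zero per column.
toy_ladj_shift(a::Union{Real,AbstractVector}, x::AbstractMatrix{T}, ::Val{1}) where {T<:Real} = zeros(T, size(x, 2))

a = randn(3, 4)
x = randn(3, 4)

# No method accepts (matrix shift, matrix input, Val{2}), so dispatch fails.
err = try
    toy_ladj_shift(a, x, Val(2))
    nothing
catch e
    e
end
println(err isa MethodError)

# Adding the missing combination closes the gap.
toy_ladj_shift(a::Union{Real,AbstractMatrix}, x::AbstractMatrix{T}, ::Val{2}) where {T<:Real} = zero(T)
println(toy_ladj_shift(a, x, Val(2)))
```
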

I was able to fix this (for the specific case where I hit the error) by simply adding:

Bijectors._logabsdetjac_shift(a::T1, x::AbstractMatrix{T2}, ::Val{2}) where {T1<:Union{Real, AbstractMatrix}, T2<:Real} = zero(T2)
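Returning zero here makes sense because a shift has an identity Jacobian, so its log absolute determinant is 0 when the whole matrix is treated as one sample. The following is a rough numerical check that uses only the LinearAlgebra standard library; `shift` and `numeric_jacobian` are hypothetical helpers written for this sketch and are not part of Bijectors.jl.

```julia
using LinearAlgebra

# Elementwise shift of a matrix, the operation a shift bijector performs.
shift(x, a) = x .+ a

# Central finite-difference Jacobian of f with respect to the flattened input.
function numeric_jacobian(f, x::AbstractMatrix; h=1e-6)
    n = length(x)
    J = zeros(n, n)
    for j in 1:n
        e = zeros(size(x))
        e[j] = h                      # perturb the j-th entry (linear index)
        J[:, j] = vec(f(x .+ e) .- f(x .- e)) ./ (2h)
    end
    return J
end

M, N = 3, 2
x = randn(M, N)
a = randn(M, N)

J = numeric_jacobian(v -> shift(v, a), x)
ladj = first(logabsdet(J))            # logabsdet returns (log|det J|, sign)

# The Jacobian is the identity up to floating-point error, so log|det J| is ~0.
println(isapprox(J, Matrix{Float64}(I, M * N, M * N); atol=1e-6))
println(abs(ladj) < 1e-6)
```
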

But I am not sure whether that is the best/cleanest fix for the package as a whole, or whether it again covers just one use case.

Bijectors version 0.9.7, Turing 0.16.0

mgmverburg, Aug 01 '21 14:08

Ah yes, this is a missing definition. But your solution is correct :+1: Once #183 has gone through, these things shouldn't happen.

torfjelde, Aug 02 '21 10:08