Bijectors.jl

softplus bijector?

Status: open. Opened by willtebbutt.

If someone has some time, a softplus Bijector would be cool to have.

Posted by willtebbutt on Aug 21, 2020 at 16:08

I'm currently not able to try this out, but the following should do the trick:

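```julia
using StatsFuns  # provides softplus and invsoftplus
import Bijectors: Bijector, Inversed, logabsdetjac  # needed to extend these

# Elementwise (0-dimensional) bijector mapping ℝ → (0, ∞)
struct Softplus <: Bijector{0} end

# Forward: y = log(1 + exp(x))
(::Softplus)(x::Real) = softplus(x)
(::Softplus)(x::AbstractArray{<:Real}) = softplus.(x)

# Backward: x = log(exp(y) - 1)
(::Inversed{<:Softplus})(y::Real) = invsoftplus(y)
(::Inversed{<:Softplus})(y::AbstractArray{<:Real}) = invsoftplus.(y)

# logabsdetjac (forward)
logabsdetjac(b::Softplus, x::Real) = x - log(1 + exp(x))  # I THINK this is right, haven't written it down
```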
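Writing the derivative down supports the formula: the derivative of softplus is the logistic function $\sigma(x)$, which is always positive, so the log-abs-det-Jacobian is simply $\log\sigma(x)$. The last equality also gives the form used further down the thread:

```math
\begin{aligned}
y &= \operatorname{softplus}(x) = \log\left(1 + e^{x}\right), \\
\frac{dy}{dx} &= \frac{e^{x}}{1 + e^{x}} = \sigma(x) > 0, \\
\log\left|\frac{dy}{dx}\right| &= x - \log\left(1 + e^{x}\right) = -\log\left(1 + e^{-x}\right).
\end{aligned}
```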

I'll make a PR and such when I'm back at the desk 👍

Sidenote: I think I realized a way we can avoid this code duplication for making things work when we look at "batches" (i.e. define `(::Bijector{0})(x::AbstractArray{<:Real})` once and have all `<:Bijector{0}` inherit it) 🎉
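A minimal sketch of that idea, assuming Julia ≥ 1.3 (which allows adding call methods to abstract types). To stay self-contained it does not use Bijectors.jl: `AbstractScalarBijector` is a hypothetical stand-in for `Bijector{0}`, and `Softplus` is redefined locally with a Base-only, numerically stable softplus.

```julia
# Hypothetical stand-in for Bijectors' `Bijector{0}` (elementwise bijectors).
abstract type AbstractScalarBijector end

# Defined ONCE for the abstract type: every subtype that implements a scalar
# method automatically accepts arrays by broadcasting that scalar method.
(b::AbstractScalarBijector)(x::AbstractArray{<:Real}) = b.(x)

struct Softplus <: AbstractScalarBijector end

# Only the scalar method is written per bijector.
# Numerically stable log(1 + exp(x)) using only Base.
(::Softplus)(x::Real) = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x))

ys = Softplus()([-1.0, 0.0, 2.0])  # dispatches to the abstract array method
```

A new elementwise bijector then only needs its scalar method; the batched behaviour comes for free from the supertype.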

Posted by torfjelde on Aug 21, 2020 at 22:08

> Sidenote: I think I realized a way we can avoid this code duplication for making things work when we look at "batches" (i.e. define `(::Bijector{0})(x::AbstractArray{<:Real})` once and have all `<:Bijector{0}` inherit it) 🎉

An alternative (as done, e.g., by `Distributions.logpdf`) would be to require all users to broadcast explicitly, i.e., users would simply call `Softplus().(x)` when `x` is a vector.
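Under that convention a bijector only defines its scalar method, and batching is the caller's responsibility via dot syntax. A minimal self-contained sketch, again with a locally defined, hypothetical `Softplus` rather than the Bijectors.jl type:

```julia
struct Softplus end

# Only a scalar method is defined; there is deliberately no array method.
(::Softplus)(x::Real) = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x))

x = [-1.0, 0.0, 2.0]
y = Softplus().(x)  # explicit broadcasting applies the scalar method elementwise

# Calling without the dot fails, so the batching intent must be explicit.
threw = try
    Softplus()(x)
    false
catch err
    err isa MethodError
end
```

The trade-off is fewer methods and no ambiguity about what an array input means, at the cost of users having to remember the dot.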

Posted by devmotion on Aug 21, 2020 at 22:08

> `logabsdetjac(b::Softplus, x::Real) = x - log(1 + exp(x))  # I THINK this is right, haven't written it down`

Assuming it's correct (haven't thought about it), one should probably implement it as

```julia
logabsdetjac(::Softplus, x::Real) = -log1pexp(-x)
```
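The two expressions are equal in exact arithmetic, but the naive one overflows for large `x`: `exp(1000.0)` is `Inf`, so `x - log(1 + exp(x))` evaluates to `-Inf` instead of (approximately) `0`. The sketch below compares them, using a local Base-only `log1pexp_stable` as a hypothetical stand-in for `StatsFuns.log1pexp`, and checks the stable version against a central finite difference of softplus.

```julia
# Hypothetical stand-in for StatsFuns/LogExpFunctions `log1pexp`: log(1 + exp(x)).
log1pexp_stable(x::Real) = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x))

# Original formulation: exact in real arithmetic, but exp(x) overflows for large x.
ladj_naive(x::Real) = x - log(1 + exp(x))

# Suggested formulation: log(logistic(x)) = -log(1 + exp(-x)).
ladj_stable(x::Real) = -log1pexp_stable(-x)

# Central finite-difference estimate of log|d softplus(x)/dx| for comparison.
function ladj_fd(x::Real; h = 1e-5)
    d = (log1pexp_stable(x + h) - log1pexp_stable(x - h)) / (2h)
    return log(abs(d))
end

naive_big = ladj_naive(1000.0)    # exp overflows to Inf, result is -Inf
stable_big = ladj_stable(1000.0)  # true value ≈ -exp(-1000), which rounds to -0.0
```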

Posted by devmotion on Aug 21, 2020 at 22:08

This is now provided by the ChangesOfVariables extension of LogExpFunctions: https://github.com/JuliaStats/LogExpFunctions.jl/blob/a1c4fda2b9cc4c59c184648c0cfc7f694c415bf3/ext/LogExpFunctionsChangesOfVariablesExt.jl#L7-L10

Posted by sethaxen on Sep 19, 2023 at 18:09