Flux.jl

Hessian vector products with moderately complex models

Status: open · colinxs opened this issue 5 years ago · 2 comments

I'm attempting to compute Hessian vector products for use with RL algorithms like Natural Policy Gradient or TRPO, but have been entirely unsuccessful.

Following https://github.com/FluxML/Zygote.jl/issues/115, https://github.com/JuliaDiffEq/SparseDiffTools.jl, and elsewhere, I was able to compute HVPs for simple models parameterized by a single Array, but the model below appears to have trouble inferring the type of `Dual`.

Any help would be greatly appreciated! :)

# Zygote v0.4.1, Flux v0.10.0, ForwardDiff v0.10.7, DiffRules 0.1.0, ZygoteRules 0.2.0

# Julia Version 1.3.0
# Commit 46ce4d7933 (2019-11-26 06:09 UTC)
# Platform Info:
#   OS: Linux (x86_64-pc-linux-gnu)
#   CPU: Intel(R) Core(TM) i9-7960X CPU @ 2.80GHz
#   WORD_SIZE: 64
#   LIBM: libopenlibm
#   LLVM: libLLVM-6.0.1 (ORCJIT, skylake)

using Flux, ForwardDiff, Zygote
using LinearAlgebra

# A Gaussian policy with diagonal covariance
"""
    DiagGaussianPolicy(meanNN, logstd)

Gaussian policy with diagonal covariance. `meanNN` maps a feature vector to the
mean action; `logstd` holds the log standard deviation of each action dimension.
"""
struct DiagGaussianPolicy{M,L<:AbstractVector}
    meanNN::M    # callable giving the mean action (a Flux `Chain` in `test`)
    logstd::L    # one log standard deviation per action dimension
end

# Register both fields as Flux children so that `Flux.params` and
# `Flux.paramtype` see the mean network's arrays and `logstd`.
Flux.@functor DiagGaussianPolicy

# Calling the policy returns only the mean action; `logstd` is not used here.
(policy::DiagGaussianPolicy)(features) = policy.meanNN(features)

# log(pi_theta(a | s))
"""
    loglikelihood(P::DiagGaussianPolicy, feature::AbstractVector, action::AbstractVector)

Return the log-density `log π(action | feature)` of the diagonal Gaussian policy `P`:

    -Σᵢ zᵢ² / 2 - Σᵢ logstdᵢ - n * log(2π) / 2,   zᵢ = (μᵢ - actionᵢ) / exp(logstdᵢ),

where `μ = P(feature)` and `n = length(P.logstd)`.

Throws a `DimensionMismatch` if `action` and `P.logstd` have different lengths.
"""
function loglikelihood(P::DiagGaussianPolicy, feature::AbstractVector, action::AbstractVector)
    # Check lengths up front: broadcasting would silently expand a length-1 `action`,
    # and indexing `logstd` by `action`'s length would read it only partially.
    length(action) == length(P.logstd) || throw(DimensionMismatch(
        "action has length $(length(action)), policy has $(length(P.logstd)) log-stds"))
    meanact = P(feature)
    # Whole-array broadcasts instead of a scalar loop: under Zygote every scalar
    # `getindex` gets its own pullback that allocates an array the size of its
    # input, which is slow and complicates nested differentiation.
    zs = (meanact .- action) ./ exp.(P.logstd)
    # Square by multiplication: `.^ 2` has been reported to break nested
    # differentiation here (possibly https://github.com/FluxML/Zygote.jl/issues/405).
    sqz = zs .* zs
    return -sum(sqz) / 2 - sum(P.logstd) - length(P.logstd) * log(2pi) / 2
end

"""
    flatgrad(f, ps)

Differentiate the zero-argument function `f` with respect to the parameters `ps`
(via `Zygote.gradient(f, ps)`) and concatenate all gradients into a single flat
vector, in the iteration order of `ps`.

A parameter whose gradient is `nothing` (i.e. `f` does not depend on it)
contributes zeros of the parameter's size. The result therefore always has one
entry per scalar parameter.
"""
function flatgrad(f, ps)
    gs = Zygote.gradient(f, ps)
    # Iterate `ps` directly instead of through a comprehension. A comprehension
    # preallocates from `length(ps)`, whereas this loop only relies on iteration.
    parts = AbstractVector[]
    for p in ps
        g = gs[p]
        push!(parts, g === nothing ? vec(zero(p)) : vec(g))
    end
    # `reduce(vcat, ...)` concatenates in one pass without splatting.
    return reduce(vcat, parts)
end

# NOTE(review): these three methods are type piracy. They change `length`, `size`
# and `eltype` for *every* `Params` object (which presumably otherwise report the
# number of parameter arrays), so that `Zygote.forward_jacobian` sees the total
# number of scalar parameters.
#
# The count and element type are computed from `ps` rather than hard-coded
# (228 / Float32 only matched the Dense(4, 32) -> Dense(32, 2) + 2-element
# `logstd` model in `test`). `mapreduce` with `init` only iterates `ps`, so these
# methods do not recurse into themselves.
#
# This alone does not make forward-mode seeding of `Params` work. The stack trace
# in this issue shows the seeding step zipping whole parameter arrays with an
# index range and trying to build a `Dual` from an `Array{Float32,2}`.
Base.length(ps::Params) = mapreduce(length, +, ps; init = 0)
Base.size(ps::Params) = (length(ps),)
Base.eltype(ps::Params) = mapreduce(eltype, promote_type, ps; init = Union{})

"""
    hessian_vector_product(f, ps, v)

Attempt to compute the Hessian-vector product of the zero-argument loss `f`
with respect to `ps`. This is done by forward-mode differentiating
(`Zygote.forward_jacobian`) the scalar `flatgrad(f, ps) ⋅ v`, whose gradient is
`H * v`. Returns the Jacobian matrix produced by `forward_jacobian` (its second
return value).

NOTE(review): `f` takes no arguments. The seeded parameters therefore only affect
the result if `f` reads exactly the arrays held in the `ps` passed to it; confirm
with the caller. As reported in this issue, `forward_jacobian` currently fails
while seeding `Params`.
"""
function hessian_vector_product(f, ps, v)
    # No `::Vector{Float32}` assertion on the gradient: under forward-mode seeding
    # it holds `Dual` numbers, not plain Float32.
    g = let f = f
        ps -> flatgrad(f, ps)
    end
    # `⋅` returns a scalar. The former `::Vector{Float32}` assertion on this
    # result could never succeed and always raised a TypeError.
    gvp = let g = g, v = v
        ps -> g(ps) ⋅ v
    end
    return Zygote.forward_jacobian(gvp, ps)[2]
end

"""
    test()

Reproduce the reported failure. Build a Float32 `DiagGaussianPolicy` with a
4 -> 32 -> 2 mean network and zero log-std, draw a random direction over all of
its parameters plus a random feature/action pair, and request the
Hessian-vector product of the log-likelihood.
"""
function test()
    meannet = Flux.Chain(Dense(4, 32), Dense(32, 2))
    policy = Flux.paramtype(Float32, DiagGaussianPolicy(meannet, zeros(2)))
    ps = Flux.params(policy)

    nparams = sum(length, ps)
    v = rand(Float32, nparams)
    feat = rand(Float32, 4)
    act = rand(Float32, 2)

    # Zero-argument loss over the fixed policy, feature and action.
    loss = () -> loglikelihood(policy, feat, act)
    return hessian_vector_product(loss, ps, v)
end

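Calling `test()` yields the error below. (The beginning of the error message was lost when it was copied, and only its tail survives. Frame [1] shows it is raised by `ForwardDiff.throw_cannot_dual`.)

…an_dual.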
Stacktrace:
 [1] throw_cannot_dual(::Type) at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:36
 [2] ForwardDiff.Dual{Nothing,Any,12}(::Array{Float32,2}, ::ForwardDiff.Partials{12,Any}) at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:18
 [3] Dual at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:55 [inlined]
 [4] Dual at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:62 [inlined]
 [5] Dual at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:68 [inlined]
 [6] (::Zygote.var"#1565#1567"{12,Int64})(::Array{Float32,2}, ::Int64) at /home/colinxs/.julia/packages/Zygote/8dVxG/src/lib/forward.jl:8
 [7] (::Base.var"#3#4"{Zygote.var"#1565#1567"{12,Int64}})(::Tuple{Array{Float32,2},Int64}) at ./generator.jl:36
 [8] iterate at ./generator.jl:47 [inlined]
 [9] collect(::Base.Generator{Base.Iterators.Zip{Tuple{Params,UnitRange{Int64}}},Base.var"#3#4"{Zygote.var"#1565#1567"{12,Int64}}}) at ./array.jl:622
 [10] map at ./abstractarray.jl:2155 [inlined]
 [11] seed at /home/colinxs/.julia/packages/Zygote/8dVxG/src/lib/forward.jl:7 [inlined] (repeats 2 times)
 [12] forward_jacobian(::var"#340#342"{var"#339#341"{var"#343#344"{DiagGaussianPolicy{Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}},Array{Float32,1}},Array{Float32,1},Array{Float32,1}}},Array{Float32,1}}, ::Params, ::Val{12}) at /home/colinxs/.julia/packages/Zygote/8dVxG/src/lib/forward.jl:23
 [13] forward_jacobian(::Function, ::Params) at /home/colinxs/.julia/packages/Zygote/8dVxG/src/lib/forward.jl:40
 [14] hessian_vector_product(::Function, ::Params, ::Array{Float32,1}) at /home/colinxs/workspace/dev/SharedExperiments/lyceum/hvp.jl:52
 [15] test() at /home/colinxs/workspace/dev/SharedExperiments/lyceum/hvp.jl:64
 [16] top-level scope at REPL[41]:1

— colinxs, Dec 04 '19 02:12

So, after some more fiddling and realizing that ForwardDiff only works with `AbstractArray` inputs, I was able to get the above working (not yet checked for correctness). I got around the problem with a combination of RecursiveArrayTools and `Flux.fmap`. I also had to modify the `loglikelihood` expression to avoid the `.^ 2` operation, which appears to be related to https://github.com/FluxML/Zygote.jl/issues/405.

While I haven't checked the result for correctness, the solution itself isn't too ugly :).

I was also able to make this work without the custom seeding, by explicitly calculating the vector product instead, but that solution is messier and slower, so I won't bother posting it.

I'll circle back once I've cleaned things up a bit and verified that the result is correct. Until then, if anyone has suggestions, do let me know! This is a fairly critical capability for anyone doing research in ML, RL, etc.

using Flux, ForwardDiff, Zygote, RecursiveArrayTools, Random, LinearAlgebra
using Zygote: Params, Grads
using MacroTools: @forward

# A Gaussian policy with diagonal covariance
"""
    DiagGaussianPolicy(meanNN, logstd)

Gaussian policy with diagonal covariance. `meanNN` maps a feature vector to the
mean action; `logstd` holds the log standard deviation of each action dimension.
"""
struct DiagGaussianPolicy{M,L<:AbstractVector}
    meanNN::M    # callable giving the mean action (a Flux `Chain` in `test_flathvp`)
    logstd::L    # one log standard deviation per action dimension
end

# Register both fields as Flux children so that `Flux.params`, `Flux.paramtype`
# and `Flux.fmap` reach the mean network's arrays and `logstd`.
Flux.@functor DiagGaussianPolicy

# Calling the policy returns only the mean action; `logstd` is not used here.
(policy::DiagGaussianPolicy)(features) = policy.meanNN(features)

# log(pi_theta(a | s))
"""
    loglikelihood(P::DiagGaussianPolicy, feature::AbstractVector, action::AbstractVector)

Return the log-density `log π(action | feature)` of the diagonal Gaussian policy `P`,
i.e. `-Σ zᵢ²/2 - Σ logstdᵢ - n * log(2π)/2` with `zᵢ = (μᵢ - actionᵢ) / exp(logstdᵢ)`,
`μ = P(feature)` and `n = length(P.logstd)`.
"""
function loglikelihood(P::DiagGaussianPolicy, feature::AbstractVector, action::AbstractVector)
    logstd = P.logstd
    standardized = (P(feature) .- action) ./ exp.(logstd)
    # Square by explicit multiplication. Writing `(...) .^ 2` broke differentiation
    # here (possibly https://github.com/FluxML/Zygote.jl/issues/405).
    squared = standardized .* standardized
    lognorm = length(logstd) * log(2pi) / 2
    return -sum(squared) / 2 - sum(logstd) - lognorm
end

# Collect the gradients of `ps` from `gs`, in parameter order, into an
# `ArrayPartition` with one partition per parameter. Parameters whose gradient
# is `nothing` are skipped.
function flatgrad(gs::Grads, ps::Params)
    present = Iterators.filter(!isnothing, (gs[p] for p in ps))
    return ArrayPartition(present...)
end

"""
    flat_hessian_vector_product(feat, act, policy, vs::ArrayPartition)

Compute the Hessian-vector product `H * vs` of `loglikelihood(policy, feat, act)`
with respect to the parameters of `policy`, using forward-over-reverse
differentiation:

1. Each parameter array `p` is replaced by `ForwardDiff.Dual{Nothing}.(p, d)`,
   where `d` is its direction.
2. The Zygote gradient is taken with respect to these dual parameters.
3. The single partial of every gradient entry is returned.

`vs` must have one partition per element of `Flux.params(policy)`, in that order
and with matching sizes; otherwise a `DimensionMismatch` is thrown. The result is
an `ArrayPartition` with one partition per parameter that received a gradient.
"""
function flat_hessian_vector_product(feat, act, policy, vs::ArrayPartition)
    ps = Flux.params(policy)

    # Pair each parameter array with its direction by identity, in the order of
    # `ps` (the order callers use to build `vs`). A counter advanced inside `fmap`
    # would silently assume that `fmap` visits leaves in the same order.
    length(vs.x) == length(ps) || throw(DimensionMismatch(
        "expected $(length(ps)) direction partitions, got $(length(vs.x))"))
    direction = IdDict{Any,Any}()
    for (p, d) in zip(ps, vs.x)
        size(d) == size(p) || throw(DimensionMismatch(
            "direction of size $(size(d)) for a parameter of size $(size(p))"))
        direction[p] = d
    end

    # Replace every parameter array by a dual array whose single partial is its
    # direction. Other leaves (e.g. activation functions) pass through unchanged.
    dualpol = Flux.fmap(policy) do x
        haskey(direction, x) ? ForwardDiff.Dual{Nothing}.(x, direction[x]) : x
    end
    dualps = Flux.params(dualpol)

    G = let feat = feat, act = act, dualpol = dualpol
        function (ps)
            gs = gradient(() -> loglikelihood(dualpol, feat, act), ps)
            flatgrad(gs, ps)
        end
    end

    # The derivative of the gradient along the seeded direction is H * v, stored
    # in the first (and only) partial of each entry.
    return ForwardDiff.partials.(G(dualps), 1)
end


"""
    test_flathvp(T::DataType=Float32)

Seed the global RNG, then build a `DiagGaussianPolicy` with a
`dobs` -> 32 -> 32 -> `dact` mean network, zero log-std and parameters of element
type `T`. Draw a random direction `v` (one partition per parameter, in `params`
order) and a random feature/action pair, then time and return
`flat_hessian_vector_product` on them.
"""
function test_flathvp(T::DataType=Float32)
    Random.seed!(1)

    dobs, dact = 4, 2
    meannet = Chain(Dense(dobs, 32), Dense(32, 32), Dense(32, dact))
    policy = Flux.paramtype(T, DiagGaussianPolicy(meannet, zeros(dact)))

    # Draw the direction in `T`, so it matches the parameter element type instead
    # of always being Float64. Size the feature/action from `dobs`/`dact` rather
    # than repeating 4 and 2.
    v = ArrayPartition((rand(T, size(p)...) for p in params(policy))...)
    feat = rand(T, dobs)
    act = rand(T, dact)

    return @time flat_hessian_vector_product(feat, act, policy, v)
end

— colinxs, Dec 06 '19 06:12

Can you try using Lux and ComponentArrays?

— YichengDWu, Jul 26 '22 01:07