Hessian vector products with moderately complex models

Open colinxs opened this issue 5 years ago • 2 comments

I'm attempting to compute Hessian vector products for use with RL algorithms like Natural Policy Gradient or TRPO, but have been entirely unsuccessful.

Following https://github.com/FluxML/Zygote.jl/issues/115, https://github.com/JuliaDiffEq/SparseDiffTools.jl, and other sources, I was able to compute HVPs for simple models parameterized by a single Array, but the code below appears to have trouble inferring the Dual type.
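
For reference, an HVP avoids forming the full Hessian by relying on one of two standard identities. The notation here is introduced only for illustration and does not come from the linked issues: $f$ is a scalar function of the parameters $\theta$, $H$ is its Hessian, and $v$ is a fixed vector.

```latex
H(\theta)\,v
  \;=\; \nabla_\theta \big( \nabla_\theta f(\theta)^{\top} v \big)
  \;=\; \frac{d}{d\varepsilon}\, \nabla_\theta f(\theta + \varepsilon v) \Big|_{\varepsilon = 0}
```

The middle expression differentiates the scalar gradient-vector product; the right-hand expression takes a directional derivative of the gradient along $v$, which is what seeding dual numbers with $v$ computes.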

Any help would be greatly appreciated! :)

using Flux, ForwardDiff, Zygote
using LinearAlgebra

# A Gaussian policy with diagonal covariance
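struct DiagGaussianPolicy{M,L<:AbstractVector}
    meanNN::M   # network mapping features to the action mean
    logstd::L   # per-dimension log standard deviation
end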

Flux.@functor DiagGaussianPolicy

(policy::DiagGaussianPolicy)(features) = policy.meanNN(features)

# log(pi_theta(a | s))
function loglikelihood(P::DiagGaussianPolicy, feature::AbstractVector, action::AbstractVector)
    meanact = P(feature)
    ll = -length(P.logstd) * log(2pi) / 2
    for i = 1:length(action)
        ll -= ((meanact[i] - action[i]) / exp(P.logstd[i]))^2 / 2
        ll -= P.logstd[i]
    end
    ll
end

function flatgrad(f, ps)
    gs = Zygote.gradient(f, ps)
    vcat([vec(gs[p]) for p in ps]...)
end

Base.length(ps::Params) = 228 #sum(length, ps)
Base.size(ps::Params) = (228, ) #(length(ps), )
Base.eltype(ps::Params) = Float32

function hessian_vector_product(f,ps,v)
    g = let f=f
        ps -> flatgrad(f, ps)::Vector{Float32}
    end
    gvp = let g=g, v=v
        ps -> (g(ps)⋅v)::Vector{Float32}
    end
    Zygote.forward_jacobian(gvp, ps)[2]
end

function test()
    policy = Flux.paramtype(Float32, DiagGaussianPolicy(Flux.Chain(Dense(4, 32), Dense(32, 2)), zeros(2)))
    ps = Flux.params(policy)
    v = rand(Float32, sum(length, ps))
    feat = rand(Float32, 4)
    act = rand(Float32, 2)
    f = let policy=policy, feat=feat, act=act
        () -> loglikelihood(policy, feat, act)
    end
    hessian_vector_product(f, ps, v)
end

Calling test() yields:

After some more fiddling, and after realizing that ForwardDiff only works with AbstractArray inputs, I was able to get the above working (not yet checked for correctness). I worked around the AbstractArray restriction with a combination of RecursiveArrayTools and Flux.fmap. I also had to rewrite the loglikelihood expression to get rid of the .^2, which appears to be related to https://github.com/FluxML/Zygote.jl/issues/405.

While I haven't checked the result for correctness, the solution itself isn't too ugly :).

I was also able to make this work without the custom seeding, by explicitly calculating the vector product instead, but that solution is messier and slower, so I won't bother posting it.
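
As a rough illustration of the difference between seeding and an explicit vector product, here is a minimal sketch on a plain parameter vector. It is not the unposted code mentioned above: it uses ForwardDiff at both levels (forward-over-forward) rather than Zygote, and `f`, `hvp_seeded`, and `hvp_explicit` are hypothetical names chosen for this example.

```julia
using ForwardDiff, LinearAlgebra

# Toy scalar function of a flat parameter vector (hypothetical, for illustration only).
f(θ) = sum(abs2, θ) / 2 + prod(θ)

# Seeding approach: perturb θ along v with a single dual direction ε, then
# take the derivative of the gradient with respect to ε at ε = 0.
function hvp_seeded(f, θ, v)
    ForwardDiff.derivative(ε -> ForwardDiff.gradient(f, θ .+ ε .* v), zero(eltype(θ)))
end

# Explicit approach: differentiate the scalar map θ -> ∇f(θ)⋅v.
function hvp_explicit(f, θ, v)
    ForwardDiff.gradient(θ -> dot(ForwardDiff.gradient(f, θ), v), θ)
end

θ = [1.0, 2.0, 3.0]
v = [0.5, -1.0, 2.0]

hv_seeded = hvp_seeded(f, θ, v)
hv_explicit = hvp_explicit(f, θ, v)
```

The seeded version carries one extra partial through a single gradient evaluation, whereas the explicit version differentiates with respect to every entry of θ in the outer pass, which is one reason it tends to be slower.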

I'll circle back once I've cleaned things up a bit and verified the result is correct, but until then, if anyone has suggestions, do let me know! This is a fairly critical ability for anyone doing research in ML, RL, etc.

using Flux, ForwardDiff, Zygote, RecursiveArrayTools, Random, LinearAlgebra
using Zygote: Params, Grads
using MacroTools: @forward

# A Gaussian policy with diagonal covariance
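struct DiagGaussianPolicy{M,L<:AbstractVector}
    meanNN::M   # network mapping features to the action mean
    logstd::L   # per-dimension log standard deviation
end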

Flux.@functor DiagGaussianPolicy

(policy::DiagGaussianPolicy)(features) = policy.meanNN(features)

# log(pi_theta(a | s))
function loglikelihood(P::DiagGaussianPolicy, feature::AbstractVector, action::AbstractVector)
    meanact = P(feature)
    # broken (possibly related to https://github.com/FluxML/Zygote.jl/issues/405)
    #zs = ((meanact .- action) ./ exp.(P.logstd)) .^ 2
    # works
    zs = (meanact .- action) ./ exp.(P.logstd)
    zs = zs .* zs

    ll = -sum(zs)/2 - sum(P.logstd) - length(P.logstd) * log(2pi) / 2
end

flatgrad(gs::Grads, ps::Params) = ArrayPartition((gs[p] for p in ps if !isnothing(gs[p]))...)

function flat_hessian_vector_product(feat, act, policy, vs::ArrayPartition)
    ps = Flux.params(policy)

    i = 1
    dualpol = Flux.fmap(policy) do p
        if p in ps.params
            # seed each trainable array with its block of v as the dual part
            p = ForwardDiff.Dual{Nothing}.(p, vs.x[i])
            i += 1
        end
        p
    end
    dualps = params(dualpol)

    G = let feat=feat, act=act
        function (ps)
            gs = gradient(() -> loglikelihood(dualpol, feat, act), ps)
            flatgrad(gs, ps)
        end
    end

    # the first partial of each gradient entry is the corresponding entry of H*v
    ForwardDiff.partials.(G(dualps), 1)
end

function test_flathvp(T::DataType=Float32)

    dobs, dact = 4, 2
    policy = DiagGaussianPolicy(Chain(Dense(dobs, 32), Dense(32, 32), Dense(32, dact)), zeros(dact))
    policy = Flux.paramtype(T, policy)

    # draw v with element type T so the dual seeds match the parameter precision
    v = ArrayPartition((rand(T, size(p)...) for p in params(policy))...)
    feat = rand(T, 4)
    act = rand(T, 2)

    @time flat_hessian_vector_product(feat, act, policy, v)
end

Can you try using Lux and ComponentArrays?

