Flux.jl
Flux.jl copied to clipboard
Hessian vector products with moderately complex models
I'm attempting to compute Hessian vector products for use with RL algorithms like Natural Policy Gradient or TRPO, but have been entirely unsuccessful.
Following https://github.com/FluxML/Zygote.jl/issues/115, https://github.com/JuliaDiffEq/SparseDiffTools.jl, and elsewhere I was able to compute HVPs for simple models parameterized by a single Array
, but the following appears to have issues inferring the type of Dual
.
Any help would be greatly appreciated! :)
# Zygote v0.4.1, Flux v0.10.0, ForwardDiff v0.10.7, DiffRules 0.1.0, ZygoteRules 0.2.0
# Julia Version 1.3.0
# Commit 46ce4d7933 (2019-11-26 06:09 UTC)
# Platform Info:
# OS: Linux (x86_64-pc-linux-gnu)
# CPU: Intel(R) Core(TM) i9-7960X CPU @ 2.80GHz
# WORD_SIZE: 64
# LIBM: libopenlibm
# LLVM: libLLVM-6.0.1 (ORCJIT, skylake)
using Flux, ForwardDiff, Zygote
using LinearAlgebra
# A Gaussian policy with diagonal covariance
# A Gaussian policy with a state-dependent mean and a state-independent
# diagonal covariance, stored as per-dimension log standard deviations.
struct DiagGaussianPolicy{M,L<:AbstractVector}
    meanNN::M   # mean network: maps a feature vector to an action-mean vector
    logstd::L   # log standard deviation for each action dimension
end
# Register the struct with Flux so params/fmap recurse into meanNN and logstd.
Flux.@functor DiagGaussianPolicy
# Calling the policy evaluates the mean network on `features`.
(policy::DiagGaussianPolicy)(features) = policy.meanNN(features)
# log(pi_theta(a | s))
"""
    loglikelihood(P::DiagGaussianPolicy, feature::AbstractVector, action::AbstractVector)

Return `log(pi_theta(action | feature))` for a diagonal-covariance Gaussian
policy: the mean is `P(feature)` and the per-dimension standard deviation is
`exp.(P.logstd)`.
"""
function loglikelihood(P::DiagGaussianPolicy, feature::AbstractVector, action::AbstractVector)
    meanact = P(feature)
    # Normalization constant: -k/2 * log(2*pi) with k = length(P.logstd).
    ll = -length(P.logstd) * log(2pi) / 2
    # NOTE(review): assumes length(action) == length(meanact) == length(P.logstd)
    # — TODO confirm at call sites. `eachindex` replaces the `1:length(...)`
    # anti-pattern and is offset-safe.
    for i in eachindex(action)
        # Squared standardized residual for dimension i, plus the log-sigma term.
        ll -= ((meanact[i] - action[i]) / exp(P.logstd[i]))^2 / 2
        ll -= P.logstd[i]
    end
    return ll
end
"""
    flatgrad(f, ps)

Compute `Zygote.gradient(f, ps)` and return all per-parameter gradients
flattened (`vec`) and concatenated into a single vector, in `ps` iteration
order.
"""
function flatgrad(f, ps)
    gs = Zygote.gradient(f, ps)
    # `reduce(vcat, ...)` instead of `vcat(xs...)`: avoids splatting a
    # potentially long collection into varargs.
    return reduce(vcat, [vec(gs[p]) for p in ps])
end
# NOTE(review): this is type piracy on Zygote.Params (extending Base methods
# for a type owned by Zygote) — a deliberate hack so ForwardDiff's seeding
# machinery can treat `Params` as array-like. The original hard-coded 228
# (the parameter count of the test model); compute it generically so other
# models work too.
Base.length(ps::Params) = sum(length, ps)
Base.size(ps::Params) = (length(ps),)
# Assumes every parameter array is Float32 — TODO: derive from ps instead.
Base.eltype(ps::Params) = Float32
"""
    hessian_vector_product(f, ps, v)

Compute `H*v`, where `H` is the Hessian of `f` with respect to the parameters
`ps`, via forward-over-reverse: take the forward-mode Jacobian of the
gradient-vector product `∇f(ps) ⋅ v`.
"""
function hessian_vector_product(f, ps, v)
    # Reverse-mode gradient of f, flattened into a single vector.
    # `let` bindings avoid closure boxing of the captured variables.
    g = let f = f
        ps -> flatgrad(f, ps)::Vector{Float32}
    end
    # BUG FIX: a dot product returns a scalar, so the annotation must be
    # `::Float32` — the original `::Vector{Float32}` would always throw a
    # TypeError at runtime.
    gvp = let g = g, v = v
        ps -> (g(ps) ⋅ v)::Float32
    end
    return Zygote.forward_jacobian(gvp, ps)[2]
end
# Smoke test: build a small Float32 policy and run one Hessian-vector product.
function test()
    model = DiagGaussianPolicy(Flux.Chain(Dense(4, 32), Dense(32, 2)), zeros(2))
    policy = Flux.paramtype(Float32, model)
    ps = Flux.params(policy)
    nparams = sum(length, ps)
    v = rand(Float32, nparams)
    features = rand(Float32, 4)
    action = rand(Float32, 2)
    # Capture via `let` so the closed-over variables are not boxed.
    objective = let policy = policy, features = features, action = action
        () -> loglikelihood(policy, features, action)
    end
    return hessian_vector_product(objective, ps, v)
end
Calling test()
yields:
ERROR: Cannot create a dual over scalar type Array{Float32,2}. If the type behaves as a scalar, define ForwardDiff.can_dual.
Stacktrace:
[1] throw_cannot_dual(::Type) at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:36
[2] ForwardDiff.Dual{Nothing,Any,12}(::Array{Float32,2}, ::ForwardDiff.Partials{12,Any}) at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:18
[3] Dual at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:55 [inlined]
[4] Dual at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:62 [inlined]
[5] Dual at /home/colinxs/.julia/packages/ForwardDiff/DVizx/src/dual.jl:68 [inlined]
[6] (::Zygote.var"#1565#1567"{12,Int64})(::Array{Float32,2}, ::Int64) at /home/colinxs/.julia/packages/Zygote/8dVxG/src/lib/forward.jl:8
[7] (::Base.var"#3#4"{Zygote.var"#1565#1567"{12,Int64}})(::Tuple{Array{Float32,2},Int64}) at ./generator.jl:36
[8] iterate at ./generator.jl:47 [inlined]
[9] collect(::Base.Generator{Base.Iterators.Zip{Tuple{Params,UnitRange{Int64}}},Base.var"#3#4"{Zygote.var"#1565#1567"{12,Int64}}}) at ./array.jl:622
[10] map at ./abstractarray.jl:2155 [inlined]
[11] seed at /home/colinxs/.julia/packages/Zygote/8dVxG/src/lib/forward.jl:7 [inlined] (repeats 2 times)
[12] forward_jacobian(::var"#340#342"{var"#339#341"{var"#343#344"{DiagGaussianPolicy{Chain{Tuple{Dense{typeof(identity),Array{Float32,2},Array{Float32,1}},Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}}},Array{Float32,1}},
Array{Float32,1},Array{Float32,1}}},Array{Float32,1}}, ::Params, ::Val{12}) at /home/colinxs/.julia/packages/Zygote/8dVxG/src/lib/forward.jl:23
[13] forward_jacobian(::Function, ::Params) at /home/colinxs/.julia/packages/Zygote/8dVxG/src/lib/forward.jl:40
[14] hessian_vector_product(::Function, ::Params, ::Array{Float32,1}) at /home/colinxs/workspace/dev/SharedExperiments/lyceum/hvp.jl:52
[15] test() at /home/colinxs/workspace/dev/SharedExperiments/lyceum/hvp.jl:64
[16] top-level scope at REPL[41]:1
So with some more fiddling and realizing that ForwardDiff only works with AbstractArray
inputs I was able to get the above working (not yet checked for correctness). I was able to get around this through a combination of RecursiveArrayTools
and Flux.fmap
. I also had to modify the loglikelihood expression to get rid of the .^2
expression, which appears to be related to https://github.com/FluxML/Zygote.jl/issues/405.
While I haven't checked the result for correctness, the solution itself isn't too ugly :).
I was able to make this work without the custom seeding and instead explicitly calculating the vector product, but the solution is messier/slower so I won't bother posting it.
I'll circle back once I've cleaned things up a bit and verified the result is correct, but until then if anyone has suggestions do let me know! This is a fairly critical capability for anyone doing research in ML, RL, etc.
using Flux, ForwardDiff, Zygote, RecursiveArrayTools, Random, LinearAlgebra
using Zygote: Params, Grads
using MacroTools: @forward
# A Gaussian policy with diagonal covariance
# A Gaussian policy with a state-dependent mean and a state-independent
# diagonal covariance, stored as per-dimension log standard deviations.
struct DiagGaussianPolicy{M,L<:AbstractVector}
    meanNN::M   # mean network: maps a feature vector to an action-mean vector
    logstd::L   # log standard deviation for each action dimension
end
# Register the struct with Flux so params/fmap recurse into meanNN and logstd.
Flux.@functor DiagGaussianPolicy
# Calling the policy evaluates the mean network on `features`.
(policy::DiagGaussianPolicy)(features) = policy.meanNN(features)
# log(pi_theta(a | s))
"""
    loglikelihood(P::DiagGaussianPolicy, feature::AbstractVector, action::AbstractVector)

Return `log(pi_theta(action | feature))` for a diagonal-covariance Gaussian
policy, computed with broadcasting (no indexing loop).
"""
function loglikelihood(P::DiagGaussianPolicy, feature::AbstractVector, action::AbstractVector)
    meanact = P(feature)
    # broken (possibly related to https://github.com/FluxML/Zygote.jl/issues/405)
    #zs = ((meanact .- action) ./ exp.(P.logstd)) .^ 2
    # works — deliberate workaround: squaring via `zs .* zs` instead of `.^2`
    # avoids the Zygote broadcast-power differentiation failure. Do not
    # "simplify" this back to `.^ 2`.
    zs = (meanact .- action) ./ exp.(P.logstd)
    zs = zs .* zs
    # Sum of squared standardized residuals, log-sigma terms, and the
    # -k/2 * log(2*pi) normalization constant.
    ll = -sum(zs)/2 - sum(P.logstd) - length(P.logstd) * log(2pi) / 2
    ll
end
# Collect the per-parameter gradients from `gs` (in `ps` iteration order,
# skipping parameters with no gradient) into one ArrayPartition, which wraps
# the individual arrays behind a single flat-vector interface without copying.
flatgrad(gs::Grads, ps::Params) = ArrayPartition((gs[p] for p in ps if !isnothing(gs[p]))...)
"""
    flat_hessian_vector_product(feat, act, policy, vs::ArrayPartition)

Compute `H*v` — the Hessian of `loglikelihood(policy, feat, act)` with respect
to the policy parameters, times the seed vector `vs` — via forward-over-reverse:
each trainable array is replaced by `ForwardDiff.Dual`s seeded with the matching
partition of `vs`, the reverse-mode gradient is taken, and the single forward
partial of each gradient entry is extracted.
"""
function flat_hessian_vector_product(feat, act, policy, vs::ArrayPartition)
    ps = Flux.params(policy)
    i = 1
    # Rebuild the policy with every trainable array dualized against the
    # corresponding block of `vs`. `i` walks vs.x in the order fmap visits
    # the parameter arrays — assumes that order matches Flux.params(policy)
    # and hence the layout of `vs`; TODO confirm.
    dualpol = Flux.fmap(policy) do p
        if p in ps.params
            # Tag `Nothing` = an untagged Dual (no perturbation-confusion check).
            p = ForwardDiff.Dual{Nothing}.(p, vs.x[i])
            i += 1
        end
        p
    end
    dualps = params(dualpol)
    # Gradient of the log-likelihood of the dualized policy, flattened.
    # `let` bindings avoid closure boxing of the captured variables.
    G = let feat=feat, act=act
        function (ps)
            gs = gradient(() -> loglikelihood(dualpol, feat, act), ps)
            flatgrad(gs, ps)
        end
    end
    # Each gradient entry carries one partial: (d/deps) ∇L(theta + eps*v) at
    # eps = 0, i.e. the requested Hessian-vector product.
    ForwardDiff.partials.(G(dualps), 1)
end
"""
    test_flathvp(T::DataType=Float32)

Seeded smoke test: build a small policy with element type `T` and time one
call to `flat_hessian_vector_product` with a random seed vector.
"""
function test_flathvp(T::DataType=Float32)
    Random.seed!(1)
    dobs, dact = 4, 2
    policy = DiagGaussianPolicy(Chain(Dense(dobs, 32), Dense(32, 32), Dense(32, dact)), zeros(dact))
    policy = Flux.paramtype(T, policy)
    # BUG FIX: the seed vector must match the parameter eltype — the original
    # `rand(size(p)...)` always produced Float64 regardless of T.
    v = ArrayPartition((rand(T, size(p)...) for p in params(policy))...)
    # Use dobs/dact instead of repeating the literals 4 and 2.
    feat = rand(T, dobs)
    act = rand(T, dact)
    @time flat_hessian_vector_product(feat, act, policy, v)
end
Can you try using Lux and ComponentArrays?