LazyArrays.jl

Zygote compat is lacking

Open • torfjelde opened this issue 2 years ago • 11 comments

Zygote doesn't seem to interact nicely with LazyArrays.jl, e.g.:

julia> f(x) = sum(BroadcastArray(exp, x))
f (generic function with 1 method)

julia> Zygote.gradient(f, randn(10))
ERROR: type Array has no field f
Stacktrace:
  [1] adjoint
    @ ~/.julia/packages/Zygote/AS0Go/src/lib/lib.jl:229 [inlined]
  [2] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
  [3] _pullback
    @ ~/.julia/packages/LazyArrays/NYra8/src/lazyapplying.jl:50 [inlined]
  [4] _pullback(::Zygote.Context{false}, ::typeof(LazyArrays.call), ::ArrayLayouts.DenseColumnMajor, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0
  [5] _pullback
    @ ~/.julia/packages/LazyArrays/NYra8/src/lazyapplying.jl:52 [inlined]
  [6] _pullback
    @ ~/.julia/packages/LazyArrays/NYra8/src/lazybroadcasting.jl:82 [inlined]
  [7] _pullback
    @ ~/.julia/packages/LazyArrays/NYra8/src/lazybroadcasting.jl:57 [inlined]
  [8] _pullback(::Zygote.Context{false}, ::Type{BroadcastArray}, ::typeof(exp), ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0
  [9] _pullback
    @ ./REPL[48]:1 [inlined]
 [10] _pullback(ctx::Zygote.Context{false}, f::typeof(f), args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0
 [11] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:44
 [12] pullback
    @ ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:42 [inlined]
 [13] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:96
 [14] top-level scope
    @ REPL[50]:1

julia> g(x) = sum(LazyArray(@~ exp.(x)))
g (generic function with 1 method)

julia> Zygote.gradient(g, randn(10))
ERROR: MethodError: no method matching LazyArray(::Vector{Float64})
Closest candidates are:
  LazyArray(::Base.Broadcast.Broadcasted) at ~/.julia/packages/LazyArrays/NYra8/src/lazybroadcasting.jl:35
  LazyArray(::Applied) at ~/.julia/packages/LazyArrays/NYra8/src/lazyapplying.jl:193
Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0 [inlined]
 [2] _pullback(ctx::Zygote.Context{false}, f::Type{LazyArray}, args::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:9
 [3] _pullback
   @ ./REPL[53]:1 [inlined]
 [4] _pullback(ctx::Zygote.Context{false}, f::typeof(g), args::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface2.jl:0
 [5] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:44
 [6] pullback
   @ ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:42 [inlined]
 [7] gradient(f::Function, args::Vector{Float64})
   @ Zygote ~/.julia/packages/Zygote/AS0Go/src/compiler/interface.jl:96
 [8] top-level scope
   @ REPL[54]:1

The first error can be "fixed" (though I'm not certain this is the right approach) by defining a ChainRulesCore rule for the BroadcastArray constructor that defers to the pullback of Broadcast.broadcasted:

julia> using ChainRulesCore

julia> function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::Type{LazyArrays.BroadcastArray}, f, args...)
           return ChainRulesCore.rrule_via_ad(config, Broadcast.broadcasted, f, args...)
       end

julia> Zygote.refresh()

julia> Zygote.gradient(f, randn(10))
([0.24117702568683322, 2.478340448616497, 2.433266795642693, 1.6163793920298133, 1.8859252985478665, 3.9539878829654223, 1.2578105524502685, 0.48545348574922, 0.8710494256114425, 3.0853524634917076],)

Maybe the rest can be addressed this way too.

Would rules from ChainRulesCore (CRC) be welcome?

torfjelde, Jan 15 '23 21:01
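To make the math behind that rule concrete without pulling in ChainRulesCore, LazyArrays, or Zygote, here is a minimal hand-rolled sketch. LazyMap is a hypothetical stand-in for BroadcastArray, lazymap_rrule and grad_sum_exp are hypothetical helpers that mimic the (primal, pullback) shape an rrule returns, and `nothing` plays the role of NoTangent() for the constructor and function slots. For sum(exp.(x)), seeding the pullback with ones gives exp.(x) as the gradient.

```julia
# Hypothetical stand-in for LazyArrays.BroadcastArray: stores f and x and
# evaluates f(x[i]) only when an element is accessed.
struct LazyMap{T,F,A<:AbstractVector} <: AbstractVector{T}
    f::F
    x::A
end
LazyMap(f, x::AbstractVector) = LazyMap{typeof(f(first(x))),typeof(f),typeof(x)}(f, x)

Base.size(A::LazyMap) = size(A.x)
Base.IndexStyle(::Type{<:LazyMap}) = IndexLinear()
Base.getindex(A::LazyMap, i::Int) = A.f(A.x[i])

# Mimics the shape of an rrule for the constructor call LazyMap(exp, x):
# returns the primal plus a pullback yielding one tangent per positional
# slot (constructor, function, argument). `nothing` stands in for NoTangent().
function lazymap_rrule(::Type{LazyMap}, ::typeof(exp), x::AbstractVector)
    y = LazyMap(exp, x)
    # d/dx exp(x) = exp(x), applied elementwise to the incoming cotangent
    lazymap_pullback(ybar) = (nothing, nothing, ybar .* exp.(x))
    return y, lazymap_pullback
end

# Reverse pass for x -> sum(LazyMap(exp, x)): the pullback of sum seeds ybar with ones.
function grad_sum_exp(x::AbstractVector)
    y, back = lazymap_rrule(LazyMap, exp, x)
    s = sum(y)
    _, _, xbar = back(ones(eltype(y), length(y)))
    return s, xbar
end

s, g = grad_sum_exp([0.0, 1.0, -2.0])
println(g)  # elementwise exp of the input
```

In the rule from the thread, rrule_via_ad(config, Broadcast.broadcasted, f, args...) returns a pullback with one tangent for broadcasted itself, one for f, and one per argument, which lines up slot for slot with (BroadcastArray, f, args...). The primal it returns is whatever differentiating Broadcast.broadcasted produces, which (judging by the second error, where LazyArray received a Vector{Float64}) may already be a materialized array rather than a lazy one.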

Hmm... that's a good question. I'm usually hesitant to add "*Core.jl" dependencies because many of them are of questionable use, but ChainRulesCore.jl might be an exception.

One alternative is to make a glue package à la FastTransformsForwardDiff. (I'm wondering whether that should have been FastTransformsChainRulesCore.jl...)

dlfivefifty, Jan 16 '23 17:01

Either alternative is okay with me :)

torfjelde, Jan 16 '23 17:01

Just say which alternative you prefer, and I can try to contribute to it.

torfjelde, Jan 16 '23 20:01

Let's put it in a separate package for now so we can work out the kinks. We can always merge it back here (in the event there's a good reason to have it).

dlfivefifty, Jan 16 '23 20:01

This seems like a good use case for weak deps. Some packages have already started moving their ChainRules definitions to weak deps. The definitions would only be loaded on Julia >= 1.9 (unless you also want to use Requires on older Julia versions), but I think it would be the better long-term solution.

devmotion, Jan 18 '23 08:01

It would suck if we had to wait until Julia 1.9 before we could make use of this, though :confused:

torfjelde, Jan 18 '23 11:01

I assume it already works with the 1.9 beta, so I think you can already use it without compiling Julia yourself.

devmotion, Jan 18 '23 12:01

Can we do a separate package that works now, and becomes a weak dependency in Julia v1.9?

dlfivefifty, Jan 18 '23 13:01

When a weak dependency is loaded, the corresponding extension (usually a single file in the ext subfolder) is loaded too, and it is precompiled, in contrast to the Requires hacks! AFAIK no separate packages are involved: the extension only has access to the weak dependency, the package itself, and the package's hard dependencies, so making the glue package a hard dependency would defeat its purpose. An example is shown in this PR: https://github.com/JuliaMath/ChangesOfVariables.jl/pull/12

devmotion, Jan 18 '23 13:01

I see. I think a weak dependency here would be fine. I'd suggest forgetting the separate project and just requiring v1.9.

dlfivefifty, Jan 18 '23 14:01
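For reference, the weak-dependency setup discussed here amounts to a few Project.toml entries plus a module under ext/. The sketch below is an illustration under assumptions, not what LazyArrays.jl actually ships: the extension name LazyArraysChainRulesCoreExt and the [compat] bounds are hypothetical, and the fragment is parsed with Julia's TOML standard library only to check its structure.

```julia
using TOML

# Hypothetical Project.toml entries for LazyArrays.jl using a Julia >= 1.9
# package extension. ChainRulesCore goes under [weakdeps] (not [deps]), and
# [extensions] maps the extension module to the weak dependency that triggers it.
# The extension name and the compat bounds are assumptions for illustration.
project_fragment = """
[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

[extensions]
LazyArraysChainRulesCoreExt = "ChainRulesCore"

[compat]
ChainRulesCore = "1"
julia = "1.9"
"""

# Parse it only to check the shape of the tables.
project = TOML.parse(project_fragment)

for (ext, trigger) in project["extensions"]
    println(ext, " loads once ", trigger, " is loaded")
end

# The extension itself would live in ext/LazyArraysChainRulesCoreExt.jl, e.g.:
#
#     module LazyArraysChainRulesCoreExt
#     using LazyArrays, ChainRulesCore
#     function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode},
#                                   ::Type{LazyArrays.BroadcastArray}, f, args...)
#         return ChainRulesCore.rrule_via_ad(config, Broadcast.broadcasted, f, args...)
#     end
#     end
```

Setting julia = "1.9" in [compat] matches the suggestion to just require v1.9; supporting older Julia versions would need a fallback such as Requires instead.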

We use weak deps for ChangesOfVariables.jl now, and it works like a charm on Julia v1.9:

julia> @time_imports import ChangesOfVariables
      0.6 ms  ChangesOfVariables

julia> @time_imports import ChainRulesCore
      0.1 ms  Compat
     58.9 ms  ChainRulesCore
      0.4 ms  ChangesOfVariables → ChainRulesCoreExt

oschulz, Jan 26 '23 19:01