Zygote.jl
Typed Matrix literal is not differentiable.
Trying to differentiate a typed array literal fails with an error saying that mutation is not supported.
Julia Version: 1.11.5 and Zygote v0.7.7
Minimal working example:
using Zygote
T(x) = Float32[1 0 0 x[1]][4]
gradient(T, [42.0f0])
Stacktrace:
ERROR: Mutating arrays is not supported -- called setindex!(Matrix{Float32}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)
Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
https://fluxml.ai/Zygote.jl/latest/limitations
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] _throw_mutation_error(f::Function, args::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/lib/array.jl:70
[3] (::Zygote.var"#553#554"{Matrix{Float32}})(::Nothing)
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/lib/array.jl:82
[4] (::Zygote.var"#2643#back#555"{Zygote.var"#553#554"{Matrix{Float32}}})(Δ::Nothing)
@ Zygote ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:72
[5] hvcat_fill!
@ ./abstractarray.jl:2238 [inlined]
[6] (::Zygote.Pullback{Tuple{typeof(Base.hvcat_fill!), Matrix{…}, Tuple{…}}, Any})(Δ::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#672#674"{…}}, ChainRules.var"#671#673"{Matrix{…}, Float32, Tuple{…}}})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[7] typed_hcat
@ ./abstractarray.jl:1645 [inlined]
[8] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{…}, ChainRules.var"#671#673"{…}})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[9] T
@ ./REPL[3]:1 [inlined]
[10] (::Zygote.Pullback{Tuple{typeof(T), Vector{Float32}}, Tuple{Zygote.Pullback{Tuple{…}, Tuple{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface2.jl:0
[11] (::Zygote.var"#88#89"{Zygote.Pullback{Tuple{typeof(T), Vector{Float32}}, Tuple{Zygote.Pullback{Tuple{…}, Tuple{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}, Zygote.Pullback{Tuple{…}, Tuple{…}}}}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface.jl:97
[12] gradient(f::Function, args::Vector{Float32})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/compiler/interface.jl:154
[13] top-level scope
@ REPL[4]:1
Some type information was truncated. Use `show(err)` to see complete types.
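A possible workaround, sketched here as an assumption rather than taken from the report: an untyped literal such as [1f0 0f0 0f0 x[1]] lowers to plain hcat, which Zygote can already differentiate through an existing rule. Writing the constants as Float32 literals keeps the element type Float32 whenever x is Float32 (a Float64 input would promote the matrix to Float64). The name T_untyped is hypothetical.

```julia
using Zygote

# Hypothetical workaround: the untyped literal below lowers to
# hcat(1f0, 0f0, 0f0, x[1]), which has a differentiation rule,
# instead of Base.typed_hcat(Float32, ...), which does not.
# The Float32 constants keep the eltype Float32 when x is a Vector{Float32}.
T_untyped(x) = [1f0 0f0 0f0 x[1]][4]

val = T_untyped([42.0f0])             # picks out x[1] again
grad = gradient(T_untyped, [42.0f0])  # 1-tuple: gradient with respect to x
println(val, " ", grad)
```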
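The typed literal syntax lowers to a call to Base.typed_hcat, and ideally ChainRules.jl would provide a rule for it, probably a mild adaptation of the existing rule for hcat. Plain hcat differentiates fine, while typed_hcat hits the same mutation error: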
julia> Meta.@lower Float32[1 0 0 x[1]]
:($(Expr(:thunk, CodeInfo(
@ none within `top-level scope`
1 ─ %1 = Float32
│ %2 = x
│ %3 = Base.getindex(%2, 1)
│ %4 = Base.typed_hcat(%1, 1, 0, 0, %3)
└── return %4
))))
julia> Zygote.gradient(x -> sum(abs2, hcat(1, 2, x)), 3.0)
(6.0,)
julia> Zygote.gradient(x -> sum(abs2, Base.typed_hcat(Float32, 1, 2, x)), 3.0)
ERROR: Mutating arrays is not supported -- called setindex!(Matrix{Float32}, ...)
...
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] _throw_mutation_error(f::Function, args::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/wfLOG/src/lib/array.jl:70
...
[5] hvcat_fill!
@ ./abstractarray.jl:2238 [inlined]
...
[7] typed_hcat
@ ./abstractarray.jl:1645 [inlined]
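To illustrate what such a rule might look like, here is a minimal sketch of a ChainRulesCore rrule for Base.typed_hcat, assuming every argument is a scalar (as in the literal from the report). It is not the rule ChainRules.jl ships: a real rule would also need to accept array arguments and would normally project cotangents with ProjectTo. The helper names typed_sumsq and mwe are hypothetical.

```julia
using ChainRulesCore, Zygote

# Sketch only: covers Base.typed_hcat(S, x1, x2, ...) with scalar arguments,
# which builds a 1×n Matrix{S} whose k-th entry is convert(S, xk).
function ChainRulesCore.rrule(::typeof(Base.typed_hcat), ::Type{S}, xs::Number...) where {S}
    y = Base.typed_hcat(S, xs...)
    function typed_hcat_pullback(ȳ)
        Ȳ = unthunk(ȳ)
        # The cotangent for the k-th scalar is entry (1, k) of the output cotangent.
        x̄s = ntuple(k -> Ȳ[1, k], length(xs))
        # No tangent for typed_hcat itself or for the element type S.
        return (NoTangent(), NoTangent(), x̄s...)
    end
    return y, typed_hcat_pullback
end

# The two failing calls from above.
typed_sumsq(x) = sum(abs2, Base.typed_hcat(Float32, 1, 2, x))
mwe(x) = Float32[1 0 0 x[1]][4]  # same as T in the report

println(Zygote.gradient(typed_sumsq, 3.0))
println(Zygote.gradient(mwe, [42.0f0]))
```

Restricting the signature to xs::Number... keeps the index bookkeeping trivial, since the output is always a single row.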
Note that there are similar typed_hvcat and typed_hvncat functions too:
julia> Meta.@lower Float32[1 2; 3 4]
:($(Expr(:thunk, CodeInfo(
@ none within `top-level scope`
1 ─ %1 = Float32
│ %2 = Core.tuple(2, 2)
│ %3 = Base.typed_hvcat(%1, %2, 1, 2, 3, 4)
└── return %3
))))
julia> Meta.@lower Float32[1;;;]
:($(Expr(:thunk, CodeInfo(
@ none within `top-level scope`
1 ─ %1 = Base.typed_hvncat(Float32, 3, 1)
└── return %1
))))
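The same idea could extend to the other two functions. The sketch below again assumes all-scalar arguments. For typed_hvcat it relies on Base filling the matrix row by row, so the k-th argument lands at row div(k-1, ncols)+1, column rem(k-1, ncols)+1. For typed_hvncat with an integer dim, every dimension except dim has length 1, so the k-th argument sits at linear index k. As before, these are hypothetical rules rather than ChainRules.jl's API; they ignore array arguments and the other typed_hvncat call forms, and the helpers hv and hvn are illustrative names.

```julia
using ChainRulesCore, Zygote

# Sketch for Float32[a b; c d] style literals, which lower to
# Base.typed_hvcat(S, rows, xs...) with rows = (2, 2) etc.
function ChainRulesCore.rrule(::typeof(Base.typed_hvcat), ::Type{S},
                              rows::Tuple{Vararg{Int}}, xs::Number...) where {S}
    y = Base.typed_hvcat(S, rows, xs...)
    ncols = size(y, 2)
    function typed_hvcat_pullback(ȳ)
        Ȳ = unthunk(ȳ)
        # Arguments are laid out row by row, so map k to (row, column).
        x̄s = ntuple(k -> Ȳ[div(k - 1, ncols) + 1, rem(k - 1, ncols) + 1], length(xs))
        # No tangents for the function, the element type, or the row lengths.
        return (NoTangent(), NoTangent(), NoTangent(), x̄s...)
    end
    return y, typed_hvcat_pullback
end

# Sketch for Base.typed_hvncat(S, dim, xs...) with an integer dim, the form
# seen in the lowering of Float32[1;;;] above.
function ChainRulesCore.rrule(::typeof(Base.typed_hvncat), ::Type{S},
                              dim::Int, xs::Number...) where {S}
    y = Base.typed_hvncat(S, dim, xs...)
    function typed_hvncat_pullback(ȳ)
        Ȳ = unthunk(ȳ)
        # All dimensions except `dim` have length 1, so argument k is at linear index k.
        x̄s = ntuple(k -> Ȳ[k], length(xs))
        # No tangents for the function, the element type, or dim.
        return (NoTangent(), NoTangent(), NoTangent(), x̄s...)
    end
    return y, typed_hvncat_pullback
end

hv(x) = sum(abs2, Float32[1 x; 3 4])                     # lowers to typed_hvcat
hvn(x) = sum(abs2, Base.typed_hvncat(Float32, 3, 1, x))  # 1×1×2 array

println(Zygote.gradient(hv, 3.0))
println(Zygote.gradient(hvn, 3.0))
```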
Relevant issue: https://github.com/JuliaDiff/ChainRules.jl/issues/743