Zygote.jl
`unbroadcast` error due to overly strict types
I frequently get an error like:
TypeError: in typeassert, expected Tuple{T, T} where T, got a value of type Tuple{Array..., Nothing}
It is caused by https://github.com/FluxML/Zygote.jl/blob/e0af1a814b9c1b652861eda5db27ddec13a28d16/src/lib/broadcast.jl#L69, where `NTuple` forces all elements to share one type. That check fails when some gradient branch is `nothing`.
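A minimal illustration of the mismatch, using only Base Julia: `NTuple{N}` means `Tuple{T, ..., T} where T`, so a tuple that mixes an array gradient with `nothing` is rejected. `Tuple{Vararg{Any,N}}` and plain `Tuple` both accept it. The value `g` is a hypothetical stand-in for a gradient tuple, not something Zygote produces verbatim.

```julia
# Hypothetical gradient tuple: one branch has an array gradient,
# the other branch contributed no gradient (`nothing`).
g = (zeros(2), nothing)

# NTuple{2} == Tuple{T, T} where T: both slots must share one concrete type.
println(g isa NTuple{2})              # false

# Fixed length 2, but each slot may hold any type.
println(g isa Tuple{Vararg{Any,2}})   # true

# Any length, any element types.
println(g isa Tuple)                  # true
```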
I can sometimes work around it by rewriting the broadcast code, or with hacks like multiplying a branch by 0 and adding it to the loss. But could we just change `NTuple{length(x)}` to `Tuple` in that line?
That does look wrong. But `Tuple{Vararg{Any,N}}` is probably better than `Tuple` for type stability, because the length `N` stays known to the compiler, while it still allows `nothing`:
julia> NTuple{3}([1, 2, 3, 4])
(1, 2, 3)
julia> NTuple{3}([1, 2, nothing, 4])
ERROR: TypeError: in typeassert, expected Tuple{T, T} where T, got a value of type Tuple{Int64, Nothing}
Stacktrace:
...
[3] (Tuple{T, T, T} where T)(itr::Vector{Union{Nothing, Int64}})
@ Base ./tuple.jl:455
julia> Tuple{Vararg{Any,3}}([1, 2, nothing, 4])
(1, 2, nothing)
julia> @code_warntype Tuple{Vararg{Any,3}}([1, 2, nothing, 4])
MethodInstance for Tuple{Any, Any, Any}(::Vector{Union{Nothing, Int64}})
from (::Type{T})(itr) where T<:Tuple @ Base tuple.jl:455
Static Parameters
T = Tuple{Any, Any, Any}
...
Body::Tuple{Union{Nothing, Int64}, Union{Nothing, Int64}, Union{Nothing, Int64}}
...
julia> @code_warntype Tuple([1, 2, nothing, 4])
...
Body::Tuple{Vararg{Union{Nothing, Int64}}}
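The same comparison can be checked programmatically with `Base.return_types` instead of reading `@code_warntype` output. This is a sketch under stated assumptions: `grads_to_tuple_fixed` and `grads_to_tuple_open` are hypothetical helper names, not Zygote functions, and the length is hardcoded to 3 to mirror the session above (in `unbroadcast` it would come from `length(x)`). Exact inferred element types can differ between Julia versions, so only the length information is checked.

```julia
# Hypothetical helpers, not Zygote API.
# Fixed length 3, arbitrary element types (so `nothing` is allowed).
grads_to_tuple_fixed(v) = Tuple{Vararg{Any,3}}(v)

# Unconstrained `Tuple`: the length is only known at run time.
grads_to_tuple_open(v) = Tuple(v)

v = [1, 2, nothing, 4]              # Vector{Union{Nothing, Int64}}

println(grads_to_tuple_fixed(v))    # (1, 2, nothing)
println(grads_to_tuple_open(v))     # (1, 2, nothing, 4)

argtypes = (typeof(v),)
rt_fixed = only(Base.return_types(grads_to_tuple_fixed, argtypes))
rt_open  = only(Base.return_types(grads_to_tuple_open, argtypes))

# Inference knows the fixed version returns exactly three elements...
println(rt_fixed <: NTuple{3,Any})  # true
# ...but cannot promise a length for the open version.
println(rt_open <: NTuple{3,Any})   # false
```

Keeping `N` in the type costs nothing in flexibility (each slot is still `Any`), while downstream code that destructures the gradient tuple can rely on its length being known at compile time.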