ChainRules.jl
rrule for Dict constructor
Zygote does not seem smart enough to handle this on its own.
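For readers following along, here is a minimal sketch of what such a rule could look like. It is not the code from this PR. It assumes that keys are non-differentiable (they get `NoTangent()`) and that the output cotangent arrives as an `AbstractDict` mapping keys to value cotangents, which is how Zygote represents Dict gradients. Anything else is treated as zero. It deliberately skips projection and type stability, which are the open questions below. `Dict_pullback` is just an illustrative name.

```julia
using ChainRulesCore

# Hypothetical sketch of an rrule for the varargs constructor Dict(k1 => v1, k2 => v2, ...).
# Assumptions: keys are non-differentiable, and the output cotangent is an AbstractDict
# mapping keys to value cotangents; any other cotangent (e.g. ZeroTangent) is treated as zero.
function ChainRulesCore.rrule(::Type{Dict}, ps::Pair...)
    d = Dict(ps...)
    function Dict_pullback(Δ)
        Δd = unthunk(Δ)
        Δps = ntuple(length(ps)) do i
            k = ps[i].first
            # Dict compares keys with isequal; if a later pair repeats this key,
            # it overwrote this value, so this value receives no gradient.
            shadowed = any(j -> isequal(ps[j].first, k), (i + 1):length(ps))
            dv = if shadowed || !(Δd isa AbstractDict)
                ZeroTangent()
            else
                get(Δd, k, ZeroTangent())
            end
            # One structural tangent per input Pair: nothing for the key, dv for the value.
            Tangent{typeof(ps[i])}(first = NoTangent(), second = dv)
        end
        return (NoTangent(), Δps...)  # no tangent for the Dict type itself
    end
    return d, Dict_pullback
end

d, back = ChainRulesCore.rrule(Dict, :a => 1.0, :b => 2.0, :a => 3.0)
g = back(Dict(:a => 10.0, :b => 20.0))
@show g[2].second  # the first :a pair was overwritten, so it gets a zero tangent
```

Returning `NoTangent()` for every key is one possible answer to whether keys can carry gradients at all. It is a design choice for the sketch, not a settled convention.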
Pending questions:
- [ ] How to get tests working. I assume https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/195 would be required.
- [ ] Is storing projectors per-key and per-value overkill? Does a fast path for the homogeneous-type case, like the one for vect, make sense?
- [ ] This is currently very type-unstable, out of concern for handling tangent sparsity. Is that a reasonable worry, or does it (or ought it, thinking about how Zygote handles Dict tangents) not show up in practice?
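To make the projector question concrete, here is a small sketch contrasting one `ProjectTo` per value with a single shared projector for the homogeneous case. The "fast path" check is hypothetical, written here only for illustration. It assumes a homogeneous case is detected by every value having the same concrete type, and the input pairs are made up.

```julia
using ChainRulesCore

# Illustrative inputs: all values share the concrete type Float32.
ps = (:a => 1.0f0, :b => 2.0f0, :c => 3.0f0)

# General case: one projector per value, so each value's cotangent is
# projected onto the type of that particular value.
per_value = map(p -> ProjectTo(p.second), ps)

# Hypothetical fast path: when every value has the same concrete type,
# a single projector built from the first value could be shared.
V = typeof(first(ps).second)
homogeneous = all(p -> typeof(p.second) === V, ps)
shared = homogeneous ? ProjectTo(first(ps).second) : nothing

# A Float64 cotangent (as an AD system might produce) projected back to Float32.
dv = 0.5
@show per_value[2](dv)
@show shared(dv)
```

Both give the same result when the values are homogeneous. The per-value version also covers mixed value types, at the cost of storing one projector per pair.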
Not on your list, but: Is it possible/legal to have a nonzero gradient for keys of a dictionary at all?
Re projection for ::Pairs..., maybe this can't get too complicated: unlike vect, there is no automatic promotion, although you can write Dict{Int,Float32}(1=>2, 3=>4.0). (You can also write Dict([1=>2, 3=>4.0+im]); do the array rules need to handle that?)
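The construction paths mentioned above can be checked in plain Julia, with no AD involved. This sketch uses only Base, and the variable names are illustrative. For the varargs case without eltypes, the comment relies on the point made above that no vect-style promotion happens, so no specific output type is claimed.

```julia
# Explicit eltypes: the constructor converts every key and value,
# so the Int value 2 is stored as 2.0f0.
d1 = Dict{Int,Float32}(1 => 2, 3 => 4.0)
@show valtype(d1)

# Vector literal: vect promotes the Pair element types before Dict sees them,
# so the Int value 2 ends up stored as a ComplexF64.
d2 = Dict([1 => 2, 3 => 4.0 + im])
@show valtype(d2)

# Varargs without eltypes: Dict works out the eltype from the pairs itself,
# without the numeric promotion that vect performs.
d3 = Dict(1 => 2, 3 => 4.0)
@show valtype(d3)
```

In the `d2` case, the promotion happens inside the vector literal, before the Dict constructor runs. That is why the question of whether the array rules need to handle it arises.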
Can we polish this up and merge it, maybe with just a quick Zygote test while waiting for https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/195? The current implementation looks fine to me. It would greatly improve life for people trying to work with dictionaries; at the moment I can't find a way to construct one in a gradient context.
I don't foresee having the bandwidth or necessary knowledge to take this further, so if anyone wants to take over the work in this PR please don't hesitate to do so :)