ChainRules.jl

rrule for Dict constructor

Open ToucheSir opened this issue 3 years ago • 1 comments

Zygote does not seem smart enough to handle this on its own.

Pending questions:

  • [ ] How to get tests working. I assume https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/195 would be required.
  • [ ] Is storing projectors per-key and per-value overkill? Does a fast homogeneous type case like that for vect make sense?
  • [ ] The implementation is currently very type-unstable, out of concern for handling tangent sparsity. Is that a reasonable worry, or does it (or ought it, thinking about how Zygote handles Dict tangents) not show up in practice?

ToucheSir avatar Jul 18 '22 02:07 ToucheSir

Not on your list, but: Is it possible/legal to have a nonzero gradient for keys of a dictionary at all?

Re projection, for ::Pairs... maybe this can't get too long. Unlike vect there is no automatic promotion, although you can do Dict{Int,Float32}(1=>2, 3=>4.0). (You can also do Dict([1=>2, 3=>4.0+im]); do the array rules need to handle that?)

mcabbott avatar Jul 18 '22 03:07 mcabbott

Can we polish this up and merge it, maybe with just a quick Zygote test while waiting for https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/195? The current implementation looks fine to me. It would greatly improve life for people trying to work with dictionaries; at the moment I can't find a way to construct one in a gradient context.

CarloLucibello avatar Jun 14 '23 17:06 CarloLucibello

I don't foresee having the bandwidth or the necessary knowledge to take this further, so if anyone wants to take over the work in this PR, please don't hesitate to do so :)

ToucheSir avatar Jun 14 '23 19:06 ToucheSir