Yota.jl

Flux Integration

Open MikeInnes opened this issue 5 years ago • 2 comments

I'm curious if you'd be interested in making Yota compatible with Flux layers and optimisers; then Yota could be used in place of Tracker for models without control flow.

Zygote does this by inserting calls to unwrap which strip away Flux tracking (this of course won't be necessary when we get rid of Tracker).

MikeInnes, Feb 26 '19 11:02

Yes, I thought about it, but I don't see a straightforward way to do it just yet - unwrapping tracked data helps in finding gradients, but how do you use them after that?

For example, let's take _update_params!(opt, xs) - it expects each x to have .data and .grad properties, i.e. to be a Tracked* type. Zygote by itself doesn't use tracked data, so do you just write the computed gradients back to the original tracked variables? If so, do you have a pointer to the relevant piece of code?
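
To make the question concrete, here is a rough sketch of the "write the gradients back" idea. Everything in it is hypothetical: ToyTracked, write_back! and toy_update_params! are illustrative names, not Flux's or Tracker's real types or functions. The only thing taken from the discussion is the assumption that the update step reads .data and .grad from each parameter.

```julia
# Hypothetical stand-in for a Tracked* array; NOT the real Tracker type.
struct ToyTracked{A<:AbstractArray}
    data::A   # current parameter values
    grad::A   # gradient buffer, same shape as data
end

# Convenience constructor: start with a zero gradient buffer.
ToyTracked(data::AbstractArray) = ToyTracked(data, zero(data))

# Copy gradients computed elsewhere (e.g. by another AD tool) into .grad.
function write_back!(xs, grads)
    for (x, g) in zip(xs, grads)
        x.grad .= g
    end
    return xs
end

# Plain gradient descent over anything exposing .data and .grad
# (hypothetical; a real optimiser would do more than this).
function toy_update_params!(lr, xs)
    for x in xs
        x.data .-= lr .* x.grad   # in-place step on the parameter values
        x.grad .= 0               # reset the buffer for the next iteration
    end
    return xs
end

w = ToyTracked([1.0, 2.0])
write_back!([w], [[0.5, 0.5]])    # pretend these gradients came from outside
toy_update_params!(0.1, [w])
```

The point of the sketch is only that the optimiser never needs to know where the gradients came from, as long as something fills in .grad before the update runs.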

dfdx, Feb 28 '19 07:02

I spent some time refactoring this stuff for exactly this reason. We now have this version of update which only requires a param -> grad mapping. It should be pretty easy to get that out of Yota.
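
As a rough illustration of what an update driven by a param -> grad mapping could look like, here is a minimal sketch. It is not the Flux implementation: toy_update! is a made-up name, plain gradient descent stands in for a real optimiser, and using an IdDict keyed by the parameter arrays is an assumption made for the example.

```julia
# Hypothetical helper, not Flux's actual update function.
# `grads` maps each parameter array (by object identity) to its gradient.
function toy_update!(lr, params, grads::IdDict)
    for p in params
        haskey(grads, p) || continue   # skip parameters with no gradient
        p .-= lr .* grads[p]           # in-place gradient descent step
    end
    return params
end

W = [1.0 2.0; 3.0 4.0]
b = [0.0, 0.0]

# Any AD tool that can produce this mapping could drive the update.
grads = IdDict{Any,Any}(W => ones(2, 2), b => [1.0, -1.0])

toy_update!(0.5, (W, b), grads)
```

Keying by identity rather than by value matters here: two parameters that happen to hold equal numbers would collide in an ordinary Dict, but stay distinct in an IdDict.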

The Param API is somewhat transitional; it represents a compromise between what Flux and Yota/Zygote can expose, but the idea is to eventually get rid of it in favour of more functional optimisers. Once that happens it should be possible to use Yota's native API with Flux.
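
For a sense of what "more functional" might mean, here is a minimal sketch under my own assumptions: the optimiser takes the current value, its gradient and its own state, and returns new values instead of mutating anything. ToyMomentum, init_state and toy_step are hypothetical names chosen for this example, not an existing or planned Flux API.

```julia
# Hypothetical functional optimiser: no mutation, no tracked types.
struct ToyMomentum
    lr::Float64    # learning rate
    rho::Float64   # momentum coefficient
end

# Optimiser state for one parameter: a velocity of the same shape.
init_state(::ToyMomentum, x) = zero(x)

# Pure step: returns the new parameter and the new state.
function toy_step(opt::ToyMomentum, velocity, x, g)
    new_velocity = opt.rho .* velocity .+ opt.lr .* g
    return x .- new_velocity, new_velocity
end

opt = ToyMomentum(0.1, 0.9)
x = [1.0, 2.0]
v = init_state(opt, x)

# The gradient can come from any source; x itself is left untouched.
x1, v1 = toy_step(opt, v, x, [1.0, 1.0])
```

With this shape, the only contract between the AD tool and the optimiser is "give me an array and its gradient", which is why no wrapper type is needed.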

MikeInnes, Mar 06 '19 17:03