
RFC: add a supertype to layers

Open mcabbott opened this issue 3 years ago • 5 comments

This proposes to give Flux's layer types various supertypes.

One reason to like this is that it simplifies the use of show: if your layer has the same supertype as Chain, it will be unfolded at top level just as Chain is, with no mystery functions to overload. Closes https://github.com/FluxML/Flux.jl/pull/1932, closes #2044

Edit: 568af9b goes further: We can define functor for this abstract type, eliminating the need to call @functor. (It's a pretty mysterious macro, and you can't @macroexpand it to have a look.) We can also define trainable for some abstract types; maybe that's also less mysterious.
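For illustration, here is a minimal sketch of what such a shared definition could look like (the commit's actual code may differ; AbstractLayer is the PR's proposed name, and the reconstruction step here is an assumption):

using Functors

abstract type AbstractLayer end

# One generic method for every subtype, instead of per-type @functor
# calls: the children are all fields, and reconstruction re-calls the
# struct's (unparameterised) constructor with the new field values.
function Functors.functor(::Type{<:AbstractLayer}, x)
    names = fieldnames(typeof(x))
    children = NamedTuple{names}(map(n -> getfield(x, n), names))
    reconstruct(y) = Base.typename(typeof(x)).wrapper(y...)
    return children, reconstruct
end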

Another is this: Flux.gpu and CUDA.cu both recursively move things, the latter via Adapt not Functors, which means cu does not preserve tied weights. But if we can rely on the shared arrays both living within a struct whose type we own (like Chain), then we can convert cu to something Functors-based at that boundary. (It's enough for one outer layer to do this -- using weird Dense-like layers marked only with @functor within a Chain is fine.)
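A small demonstration of the tied-weights point (nothing Flux-specific here; doubling stands in for moving arrays to the GPU):

using Functors

w = rand(Float32, 2, 2)
tied = (enc = w, dec = w)        # two references to one array

moved = fmap(x -> 2 .* x, tied)  # fmap caches results by objectid
moved.enc === moved.dec          # true: the tie is preserved

naive = map(x -> 2 .* x, tied)   # field-by-field, like Adapt's recursion
naive.enc === naive.dec          # false: now two independent arrays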

Note that this supertype is entirely optional. The PR does not change the fact that functor is how Flux walks models, etc, and so it does not restrict how you can customise things. It only aims to make it easy to opt-in to the standard behaviours, without a zoo of weird macros.

mcabbott avatar Jul 27 '22 02:07 mcabbott

Since this has now been put on the next milestone, I should chime in. I personally do not like having a super-type. My reasons:

  • It precludes someone from using their own type hierarchy. There are real examples of this: InvertibleNetworks.jl and Mill.jl. More generally, any code with functions of the form f: model -> model' will reasonably want to dispatch on the model type, and depending on what exactly f and model are, a type hierarchy could be very useful there.
  • It's a sharp change from Flux's approach to this problem in the past. If we were getting something of important value here, then it might be worth it. But it seems like this is one implementation that is subjectively more convenient than a macro.
  • One positive for this approach is that it is more natural to Julia users than a macro. I agree in general with this for types, but I don't think it holds as much weight in our case. Any user wanting to implement a custom layer is going to have to look at the documentation...these types aren't so intuitive that they should naturally come to mind. I don't see how learning that you need to @functor MyLayer is so much worse at this point.
  • I agree that @functor is weird and confusing. I think Flux ought to provide something more friendly like @layer which does the Functor stuff but also show stuff and more.
  • The current type hierarchy—PartialTrainLayer{which} <: SimpleLayer <: AbstractLayer—is already getting messy. Julia's type system just isn't good for hanging all this information on, IMO (see the never-ending traits discussions). What happens for a layer which <: ContainerLayer but also has trainable parameters of its own? (See the sketch after this list.)
  • A super-type makes adopting the defaults easy, but it doesn't avoid needing to know and override trainable, show, etc. The non-default cases are still just as annoying. I'd rather work on a solution that makes those cases better. Hiding the default case is not so hard of a problem.
  • I don't see us ever moving away from something Functors-like. We can nix Functors, but we will still have the need to walk and recurse nested structs. Having multiple levels/ways of opting into being a Flux layer seems confusing to me.
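To make the container-with-parameters point concrete (the abstract type names are from the PR; the recurrent wrapper is just an illustration):

abstract type AbstractLayer end
abstract type SimpleLayer <: AbstractLayer end
abstract type ContainerLayer <: AbstractLayer end

struct MyRecur{C,S} <: ContainerLayer  # container-like: it wraps a cell
    cell::C
    state::S   # but this field is layer-owned state, not a sub-layer
end

# Julia allows exactly one supertype, so MyRecur cannot also be a
# PartialTrainLayer{(:cell,)}; that information ends up in a method
# like trainable(::MyRecur) after all.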

It seems like the main motivation for this PR originally (though it has grown) is to make _big_show easier to opt into. Once we adopt a type hierarchy, it will be hard to change our minds, so we should think carefully about what it should be and what functionality should hang on top of it. If default show is what's at stake here, then a more minimal change might be easier for right now. Two options:

  • Define something like @layer which does @functor + show. This seems like it has no downsides. We mask the confusing @functor, and we can always introduce types on top of it later. If we want to remove it in favor of types, it seems easier to deprecate.
  • Introduce just SimpleLayer and ContainerLayer for the purpose of show only, leaving trainable, etc. as is for now (a sketch of this option follows).
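The second option could be as small as the sketch below; _big_show and _layer_show are the internal helpers behind Chain's and Dense's current printing, while everything else here is hypothetical:

using Flux

abstract type SimpleLayer end      # printed on one line, like Dense
abstract type ContainerLayer end   # unfolded at top level, like Chain

Base.show(io::IO, ::MIME"text/plain", x::ContainerLayer) = Flux._big_show(io, x)
Base.show(io::IO, ::MIME"text/plain", x::SimpleLayer) = Flux._layer_show(io, x)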

darsnack avatar Aug 23 '22 03:08 darsnack

I agree the present state of this PR is a sort-of "maximal" look at what's possible, to see what clashes result, and as noted I'm not sure that including trainable is a great idea. The "minimal" one would be just SimpleLayer and ContainerLayer.

Figuring out whether things clash with packages is one goal here. At first glance it looks like InvertibleNetworks.jl has a type hierarchy which could subtype Flux.AbstractLayer, but perhaps no finer; not sure, anyone know more details? Mill.jl I'm less sure about -- anyone know? Note that both packages mentioned do load Flux.

I also agree that some kind of Flux.@layer macro is an alternative option, maybe a smaller change from @functor / @treelike. (Maybe someone can make a PR to try it out and see what problems arise; @functor is a weird macro with eval, but perhaps it need not be.) To be useful for show it would somehow need to take options to make the Chain-like / Dense-like distinction, defaulting to Dense-like perhaps. Maybe it ought to take options to control trainable too.

Unlike a supertype, the new macro could be obligatory. The idea of translating cu to an fmap walk at the outermost layer would then be easy.
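Roughly the shape that translation could take (a sketch under assumptions: AbstractLayer stands in for whatever obligatory marker is chosen, and the real CUDA adaptor types are not spelled out):

using Adapt, Functors

abstract type AbstractLayer end   # stand-in for the opt-in marker

# At the boundary of a marked layer, stop Adapt's field-by-field walk
# and hand the whole tree to fmap, whose IdDict cache adapts each
# shared array exactly once, so ties survive cu:
Adapt.adapt_structure(to, m::AbstractLayer) = fmap(x -> Adapt.adapt(to, x), m)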

Macros are always a bit opaque. You could argue that FluxML has probably gone a little too far in avoiding types; e.g. train! took 4 arguments of ::Any, good luck remembering... but now methods(Flux.train!) makes it easier not to mess up. Although it's certainly possible to go too far the other way.

(Don't read too much into the milestone, just trying to separate future ideas from bugfixes a bit.)

mcabbott avatar Aug 23 '22 03:08 mcabbott

I don't think a macro that defines a show method should be obligatory, because then you will get method overwrites if you define your own show method. (That is one nice thing about the abstract type: it's just a fallback then.)

ericphanson avatar Aug 23 '22 09:08 ericphanson

Good point. The messy thing about macros will be the order in which stuff is called. Something like

struct MyLayer
    foo
end

# the user's own show method, defined before the macro runs:
Base.show(io::IO, m::MyLayer) = print(io, "MyLayer(", m.foo, ")")

@layer MyLayer   # macro-defined show methods would then clobber the one above

seems like a simple error to run into. The current @functor only defines functor, which is unlikely to run into this issue.

I think with something like https://github.com/FluxML/Functors.jl/issues/41, we remove trainable etc. completely. Then we are left with show, which can be opt-in:

@layer MyLayer prettyprint=:simple

I was going to prototype the Functors issue anyways, so I can also test out the @layer stuff.

Another option is to make @layer apply to the struct definition itself. This would ensure that any methods defined later, like a custom show, come after the defaults.

Alternatively, introducing the hierarchy with just SimpleLayer and ContainerLayer, only for non-parameter-related stuff like show, seems reasonable to me. At least there will be a clean demarcation between Functors stuff and other stuff.

darsnack avatar Aug 23 '22 12:08 darsnack

Yes, the macro needs at least a way to choose Chain-like vs. Dense-like printing. It's not only show methods that make this distinction, and possibly one option should define no methods for show at all. I don't think we should ever rely on the order in which methods are overwritten.

Macros which act on structs seem a step more opaque, and also more complicated to write; e.g. how should such a macro interact with @kwdef?

mcabbott avatar Aug 23 '22 12:08 mcabbott