
Enable optimisations with `Chain`

theabhirath opened this issue 3 years ago • 8 comments

In long `Chain`s in Metalhead, there are often layers that can be reduced to the identity: `Dropout(p = 0)` is a frequent occurrence, along with other similar regularisation layers (`DropBlock`, `DropPath`). Currently, according to Lux's documentation, there is an option to enable and disable optimisations that can remove these and make the model a little cleaner to go through. Is there a chance something similar could be implemented for Flux?

theabhirath avatar Jun 22 '22 06:06 theabhirath

Speaking of Lux's documentation, it looks absolutely beautiful - any chance Flux would consider using a different theme?

theabhirath avatar Jun 22 '22 06:06 theabhirath

We could transition to Pollen

darsnack avatar Jun 22 '22 10:06 darsnack

As far as optimization goes, I don't think the Lux optimizations do what you're proposing. Instead, they recursively go through the Chain to wrap functions that don't adhere to the Lux interface and to delete NoOpLayers, neither of which is an issue for Flux models.

EDIT: Okay, I see in Lux that Dropout constructors return a NoOpLayer which gets pruned by the optimization. We could do a similar thing for identity as I mentioned below. Again, what's the benefit beyond visually making the model simpler?

That being said, it wouldn't be too difficult to build the kind of optimization you're talking about using fmap (we wouldn't want a keyword interface limited to Chain like Lux's). The main question is what kind of benefit it provides. If most of these no-ops can be optimized by the compiler itself, then this kind of optimization pass isn't super useful, though I wouldn't be surprised if it gave a benefit on the backwards pass due to Zygote. Maybe a comparison of a manually optimized vs. un-optimized model from Metalhead would be good to have first.
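As a rough illustration of what such a pass could look like, here is a hypothetical sketch that prunes `identity` layers from a Chain-like structure. It uses plain nested tuples so it runs standalone; the name `prune_identity` is made up, and a real Flux version would rebuild a `Chain` (e.g. via Functors' `fmap`) rather than work on tuples:

```julia
# Hypothetical sketch: recursively prune `identity` entries from a
# Chain-like nested tuple of layers. Plain tuples stand in for Chain
# so this runs without Flux; this is not an actual Flux API.
prune_identity(layers::Tuple) = filter(l -> l !== identity, map(prune_identity, layers))
prune_identity(layer) = layer

# Toy "model" with a no-op layer in the middle.
model = (x -> 2x, identity, x -> x + 1)
optimised = prune_identity(model)   # the identity entry is removed
```

Applying the remaining layers in order, e.g. with `foldl((x, f) -> f(x), optimised; init = 3)`, gives the same result as the original model.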

darsnack avatar Jun 22 '22 12:06 darsnack

Again, what's the benefit beyond visually making the model simpler?

I'll try to benchmark to see if there's a difference. But among other things, it makes porting weights from other libraries easier, while still offering a little extra functionality for users who want to train from scratch instead of using the pre-trained weights.

theabhirath avatar Jun 22 '22 13:06 theabhirath

Just to answer a few points raised here:

  1. The optimization pass is necessary for Lux since it requires layers to follow a particular interface. For Flux it makes little sense, since Flux doesn't require a strict interface.
  2. No-ops would ideally be optimized away, but Zygote being Zygote keeps them around. It becomes worse for Dropout, where it keeps branching around :disappointed:. Though in most real-world use cases it makes very little difference; it only shows up if your model is reasonably small.

avik-pal avatar Jun 23 '22 04:06 avik-pal

2. But Zygote being Zygote keeps them around. It becomes worse for Dropout, where it keeps branching around

On this point, see https://github.com/FluxML/Flux.jl/pull/2005. I still think making `active` a type parameter as you did is cleaner, but we don't have that luxury here.
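For context, a minimal sketch of the type-parameter design mentioned above (the struct name `StaticDropout` is hypothetical; this is not Flux's or Lux's actual Dropout):

```julia
# Hypothetical sketch: encoding train/test mode in the type lets the
# compiler resolve the mode branch statically, instead of Zygote having
# to trace a runtime `active` check on every call.
struct StaticDropout{active}
    p::Float64
end

# In test mode the layer dispatches straight to a pass-through method,
# so there is no runtime branch left for Zygote to differentiate through.
(d::StaticDropout{false})(x) = x
# (A `StaticDropout{true}` method would apply the random mask; elided here.)
```

A runtime `active::Bool` field, by contrast, forces the `if active` branch into the traced code, which is the branching cost discussed above.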

ToucheSir avatar Jun 23 '22 05:06 ToucheSir

Another optimisation that caught my eye was the flattening of nested Chains. Would that be something that would maybe help with TTFG and the backward pass times? (not sure if this is already done internally somehow, though)
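To illustrate what such a flattening pass might look like, here is a hedged sketch over plain nested tuples (the name `flatten_chain` is made up, and this is not an actual Flux or Lux API):

```julia
# Hypothetical sketch: flatten nested Chain-like tuples into a single
# flat tuple of layers, preserving order. Plain tuples stand in for Chain.
flatten_chain(layer) = (layer,)
flatten_chain(layers::Tuple) =
    Tuple(Iterators.flatten(map(flatten_chain, layers)))

# A nested "model" such as ((sin, cos), tan) becomes (sin, cos, tan).
```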

theabhirath avatar Jun 23 '22 06:06 theabhirath

Not really. The short templated chains mean we can often reuse the compiled gradient code across a model.


DhairyaLGandhi avatar Jun 23 '22 06:06 DhairyaLGandhi