
DiffEqFlux Layers don't satisfy Lux API

avik-pal opened this issue • 14 comments

The DiffEqFlux layers need to satisfy https://lux.csail.mit.edu/dev/api/core/#Lux.AbstractExplicitLayer, otherwise the parameters/states returned from Lux.setup will be incorrect. As pointed out on Slack:

julia> ps, st = Lux.setup(rng, Chain(node,Dense(2=>3)))
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[0.11987843 -0.1679378; 0.36991563 0.41324985; 0.73272866 0.7062624], bias = Float32[0.0; 0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))

ps.layer_1 should not be an empty NamedTuple

https://lux.csail.mit.edu/dev/manual/interface/ -- is the most recent manual for the interface
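
For reference, a minimal sketch of how node could have been constructed to reproduce this (the layer sizes, solver, and saveat here are assumptions, not from the original report):

using Lux, DiffEqFlux, OrdinaryDiffEq, Random

rng = Random.default_rng()
# 2-in/2-out network used as the ODE right-hand side
dudt = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 2))
node = NeuralODE(dudt, (0.0f0, 1.0f0), Tsit5(); saveat = 0.1f0)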

avik-pal avatar Jun 18 '22 07:06 avik-pal

Should be easy

initialparameters(rng::AbstractRNG, node::NeuralODE) = initialparameters(rng, node.model) 
initialstates(rng::AbstractRNG, node::NeuralODE) = initialstates(rng, node.model)
parameterlength(node::NeuralODE) = parameterlength(node.model)
statelength(node::NeuralODE) = statelength(node.model)

To make setup work not only for Chain but also directly on NeuralODE, we need to add

function setup(rng::AbstractRNG, node::NeuralODE)
    return (initialparameters(rng, node), initialstates(rng, node))
end
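
With those definitions, the expectation is that both of the calls below return the wrapped Chain's parameters rather than an empty NamedTuple (a sketch, reusing the rng and node from the snippet earlier in the thread):

ps, st = Lux.setup(rng, node)                        # directly on the NeuralODE
ps, st = Lux.setup(rng, Chain(node, Dense(2 => 3)))  # ps.layer_1 should now be non-empty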

YichengDWu avatar Jun 19 '22 04:06 YichengDWu

Are you supposed to overload setup? I assume that should just follow from the interface.

ChrisRackauckas avatar Jun 19 '22 12:06 ChrisRackauckas

You just need to define

initialparameters(rng::AbstractRNG, node::NeuralODE) = initialparameters(rng, node.model) 
initialstates(rng::AbstractRNG, node::NeuralODE) = initialstates(rng, node.model)

avik-pal avatar Jun 19 '22 16:06 avik-pal

We should put an abstract type on all of the AbstractNeuralDE types and then overload from there.
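
A sketch of that, with a placeholder name for the abstract type (the real one in DiffEqFlux may differ) and assuming each neural DE type stores its network in a model field:

using Lux, Random

abstract type AbstractNeuralDELayer <: Lux.AbstractExplicitLayer end

# One pair of overloads then covers NeuralODE, NeuralDSDE, etc.
Lux.initialparameters(rng::AbstractRNG, layer::AbstractNeuralDELayer) =
    Lux.initialparameters(rng, layer.model)
Lux.initialstates(rng::AbstractRNG, layer::AbstractNeuralDELayer) =
    Lux.initialstates(rng, layer.model)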

ChrisRackauckas avatar Jun 19 '22 16:06 ChrisRackauckas

You just need to define

initialparameters(rng::AbstractRNG, node::NeuralODE) = initialparameters(rng, node.model) 
initialstates(rng::AbstractRNG, node::NeuralODE) = initialstates(rng, node.model)

For it to work, yes. Would it be nicer if the number of parameters could be printed automatically?
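
With parameterlength and statelength forwarded as above, the counts can at least be queried explicitly (the values shown assume the 2→16→2 Chain sketched at the top of the thread):

Lux.parameterlength(node)  # 82 = (2*16 + 16) + (16*2 + 2)
Lux.statelength(node)      # 0, Dense layers carry no state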

YichengDWu avatar Jun 19 '22 20:06 YichengDWu

Are you supposed to overload setup? I assume that should just follow from the interface.

I was assuming NeuralODE was not a subtype of AbstractExplicitLayer. It should be unnecessary if you are going to subtype it.

YichengDWu avatar Jun 19 '22 20:06 YichengDWu

No, even if you are not subtyping, initialparameters and initialstates are the only functions that must be implemented; parameterlength and statelength are optional. setup should never be extended.

avik-pal avatar Jun 20 '22 18:06 avik-pal

I would appreciate it if you could help me understand two questions:

  1. Is it still mandatory to implement initialstates if I just have one layer and it only needs to return NamedTuple()? I have implemented some layers without it; it looks like the source code just falls back to initialstates(::AbstractRNG, ::Any) = NamedTuple().
  2. What are the bad consequences of extending setup?

YichengDWu avatar Jun 20 '22 18:06 YichengDWu

It is meant to satisfy an interface.

  1. You are right, the default for initialstates is NamedTuple(), but this is undocumented, so it can change without being considered breaking.
  2. Extending setup is not going to solve problems for most people and sets false expectations. For example, if you extend setup for a layer that is contained inside another layer, calling Lux.setup on the outer layer will leave the internal custom layer with empty parameters and states (see the sketch below).
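
A sketch of that failure mode, using a hypothetical MyLayer that extends setup but not the interface methods:

using Lux, Random

struct MyLayer <: Lux.AbstractExplicitLayer end

# Only setup is extended; initialparameters/initialstates keep their defaults.
Lux.setup(rng::AbstractRNG, ::MyLayer) =
    ((weight = randn(rng, Float32, 3, 3),), NamedTuple())

rng = Random.default_rng()
ps, st = Lux.setup(rng, MyLayer())                       # works as intended
ps, st = Lux.setup(rng, Chain(MyLayer(), Dense(3 => 2)))
# ps.layer_1 is an empty NamedTuple: Chain recurses via initialparameters,
# so the custom setup is never called for the inner layer.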

avik-pal avatar Jun 20 '22 18:06 avik-pal

Highly appreciate the clarification you made.

YichengDWu avatar Jun 20 '22 19:06 YichengDWu

Flux doesn't care about the subtyping but Lux does, so we should subtype for Lux and then also make it a functor and we're 👍.
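
A sketch of that combination (the field names here are illustrative, not the exact DiffEqFlux struct):

using Flux, Lux

struct NeuralODE{M, T, A, K} <: Lux.AbstractExplicitLayer
    model::M
    tspan::T
    args::A
    kwargs::K
end

# Lux dispatch comes from the subtype; Flux utilities (gpu, fmap, ...)
# need the functor declaration to recurse into the wrapped model.
Flux.@functor NeuralODE (model,)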

ChrisRackauckas avatar Jun 20 '22 19:06 ChrisRackauckas

Copying over from https://github.com/SciML/DiffEqFlux.jl/pull/735. All of the DiffEqFlux layers should be an AbstractExplicitLayer, which means they should do things exactly like Dense. They should have one state, take in a state, and return a state. They should take in a neural network definition and give you back a state from setup. Basically, each should act exactly like Dense does and be able to swap in without any other code changes; if not, it's wrong. The only thing that should be different is the constructor for the layer.
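
Concretely, the contract is the same (x, ps, st) -> (y, st) call signature Dense has. A rough sketch, assuming the model/tspan/args/kwargs fields from the struct sketched in the previous comment and glossing over the real solve/adjoint plumbing:

using OrdinaryDiffEq

function (n::NeuralODE)(x, ps, st)
    dudt(u, p, t) = first(n.model(u, p, st))   # Lux layers return (y, st)
    prob = ODEProblem{false}(dudt, x, n.tspan, ps)
    sol = solve(prob, n.args...; n.kwargs...)
    return sol, st                             # same (output, state) contract as Dense
end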

@Abhishek-1Bhatt let me know if you need me to do the first one.

ChrisRackauckas avatar Jun 22 '22 08:06 ChrisRackauckas

Once it gets built, http://lux.csail.mit.edu/previews/PR70/manual/interface should describe the recommended Lux interface. For DiffEqFlux, everything should really be a subtype of http://lux.csail.mit.edu/stable/api/core/#Lux.AbstractExplicitContainerLayer, and then there would be no need to define initialparameters and initialstates. (Just a small heads-up: there will be a small breaking change for the container layers in v0.5, which is still far out.)
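
A sketch of the container-layer version, assuming the wrapped network lives in a model field:

struct NeuralODE{M, T, A, K} <: Lux.AbstractExplicitContainerLayer{(:model,)}
    model::M
    tspan::T
    args::A
    kwargs::K
end
# No initialparameters/initialstates methods needed: the container layer
# forwards both to the model field automatically.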

avik-pal avatar Jun 23 '22 05:06 avik-pal

Ahh yes, I had thought about it some time ago https://julialang.slack.com/archives/C7T968HRU/p1655536943724979?thread_ts=1655535510.205359&cid=C7T968HRU but we didn't discuss it, so we ended up subtyping AbstractExplicitLayer.

ba2tro avatar Jun 23 '22 05:06 ba2tro

Done in https://github.com/SciML/DiffEqFlux.jl/pull/750

ChrisRackauckas avatar Jan 17 '23 14:01 ChrisRackauckas