
How can I integrate an RNN into an ODEProblem in Lux.jl?


Hi! Just wondering how an RNN could be mixed into an ODEProblem.

In the Flux days, it seems a Recur layer needed to be created (see Training of UDEs with recurrent networks). However, there is already a Recurrence layer in Lux.jl.

How can Lux.jl do the job now? I defined my own GRU cell, and it runs well combined with the beginner tutorial Training a Simple LSTM:


using ConcreteStructs: @concrete
using Lux
using Static
using Random

const IntegerType = Union{Integer, Static.StaticInteger}
const BoolType = Union{Bool, StaticBool, Val{true}, Val{false}}

@concrete struct FastGRUCell <: Lux.AbstractRecurrentCell
    train_state <: StaticBool
    in_dims <: IntegerType
    out_dims <: IntegerType
    init_bias
    init_weight
    init_state
    dynamics_nonlinearity
    gating_nonlinearity
    α <: AbstractFloat
    layernormQ::StaticBool
end

function FastGRUCell(
        (in_dims, out_dims)::Pair{<:Lux.IntegerType, <:Lux.IntegerType},
        Δt::T, τ::T, layernormQ::BoolType;
        train_state::BoolType=False(),
        init_weight=Lux.glorot_normal,
        init_bias=Lux.zeros32,
        init_state=zeros32,
        dynamics_nonlinearity=Lux.sigmoid_fast,
        gating_nonlinearity=Lux.tanh_fast) where {T <: AbstractFloat}
    init_weight = ntuple(Returns(init_weight), 3)
    init_bias = ntuple(Returns(init_bias), 3)
    α = Δt / τ   # leak rate: integration step over time constant
    return FastGRUCell(
        static(train_state),
        in_dims, out_dims, init_bias, init_weight, init_state,
        dynamics_nonlinearity, gating_nonlinearity, α, static(layernormQ)
    )
end

function Lux.initialparameters(rng::AbstractRNG, gru::FastGRUCell)
    # hidden-to-hidden weights
    Wz, Wr, Wh = (Lux.init_rnn_weight(
        rng, init_weight, gru.out_dims, (gru.out_dims, gru.out_dims)) for init_weight in gru.init_weight)
    # input-to-hidden weights
    Uz, Ur, Uh = (Lux.init_rnn_weight(
        rng, init_weight, gru.out_dims, (gru.out_dims, gru.in_dims)) for init_weight in gru.init_weight)

    ps = (; Wz, Wr, Wh, Uz, Ur, Uh)

    # biases for the hidden-to-hidden terms
    biasz, biasr, biash = (Lux.init_rnn_weight(
        rng, init_bias, gru.out_dims, gru.out_dims) for init_bias in gru.init_bias)

    ps = merge(ps, (; biasz, biasr, biash))
    Lux.has_train_state(gru) && (ps = merge(ps, (hidden_state=gru.init_state(rng, gru.out_dims),)))
    return ps
end
Lux.initialstates(rng::AbstractRNG,::FastGRUCell) = (rng=Lux.Utils.sample_replicate(rng),)

function (gru::FastGRUCell{True})(x::AbstractMatrix,ps,st::NamedTuple)
    hidden_state = Lux.init_trainable_rnn_hidden_state(ps.hidden_state, x)
    return gru((x, (hidden_state,)), ps, st)
end

function (gru::FastGRUCell{False})(x::AbstractMatrix, ps, st::NamedTuple)
    rng = Lux.replicate(st.rng)
    st = merge(st, (; rng))
    hidden_state = Lux.init_rnn_hidden_state(rng, gru, x)
    return gru((x, (hidden_state,)), ps, st)
end

const _FastGRUCellInputType = Tuple{
    <:AbstractMatrix, Tuple{<:AbstractMatrix}}

function (m::FastGRUCell)(
        (x, (h,))::_FastGRUCellInputType, ps, st::NamedTuple)
    # recurrent (W*h) and input (U*x) contributions to the update (z) and reset (r) gates
    Wzh = fused_dense_bias_activation(identity, ps.Wz, h, ps.biasz)
    Wrh = fused_dense_bias_activation(identity, ps.Wr, h, ps.biasr)
    Uzx = fused_dense_bias_activation(identity, ps.Uz, x, nothing)
    Urx = fused_dense_bias_activation(identity, ps.Ur, x, nothing)

    z = dynamic(m.layernormQ) ? (m.gating_nonlinearity.(layernorm(Wzh, nothing, nothing) .+ Uzx)) : (@. m.gating_nonlinearity(Wzh + Uzx))
    r = dynamic(m.layernormQ) ? (m.gating_nonlinearity.(layernorm(Wrh, nothing, nothing) .+ Urx)) : (@. m.gating_nonlinearity(Wrh + Urx))

    # candidate hidden state, with the reset gate applied to the recurrent input
    Whh = fused_dense_bias_activation(identity, ps.Wh, h .* r, ps.biash)
    Uhh = fused_dense_bias_activation(identity, ps.Uh, x, nothing)
    h̃ = dynamic(m.layernormQ) ? (m.dynamics_nonlinearity.(layernorm(Whh, nothing, nothing) .+ Uhh)) : (@. m.dynamics_nonlinearity(Whh + Uhh))

    # leaky (Euler-discretized) update with step ratio α = Δt/τ
    h′ = @. (1 - m.α * z) * h + m.α * z * h̃
    return (h′, (h′,)), st
end
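
# For completeness: I believe a cell like this can also be wrapped in Lux's `Recurrence`
# layer to run over a whole sequence in one call (untested sketch; assumes the default
# (features, time, batch) input layout and that `return_sequence=true` is supported):
rng_demo = Xoshiro(0)
cell_demo = FastGRUCell(2 => 8, 0.01f0, 1.0f0, true)
seq_model = Recurrence(cell_demo; return_sequence=true)
ps_demo, st_demo = Lux.setup(rng_demo, seq_model)
x_demo = rand(Float32, 2, 50, 16)                # (in_dims, sequence_length, batch_size)
ys_demo, _ = seq_model(x_demo, ps_demo, st_demo) # one output per time step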


# --------------------------------------------------------------------------------------------------
# adapted from https://lux.csail.mit.edu/stable/tutorials/beginner/3_SimpleRNN#Creating-a-Classifier
using Lux, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end

struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:fastgru_cell, :classifier)}
    fastgru_cell::L
    classifier::C
end
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        FastGRUCell(in_dims => hidden_dims, 0.01f0, 1.0f0, true), 
        Dense(hidden_dims => out_dims, sigmoid))
end

function (s::SpiralClassifier)(
    x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}

    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_fastgru = s.fastgru_cell(x_init, ps.fastgru_cell, st.fastgru_cell)

    for x in x_rest
        (y, carry), st_fastgru = s.fastgru_cell((x, carry), ps.fastgru_cell, st_fastgru)
    end

    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    st = merge(st, (classifier=st_classifier, fastgru_cell = st_fastgru))

    return vec(y), st
end

 # ----- loss
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)

# ----- training

function main(model_type)
    dev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() .|> dev

    # Create the model
    model = model_type(2, 8, 1)

    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            # x: (2, 50, 128) and y: (128,), i.e. (features, time steps, batch)
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)

            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end

        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            acc = accuracy(ŷ, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end

    return (train_state.parameters, train_state.states) |> cpu_device()
end
ps_trained, st_trained = main(SpiralClassifier)

When I try to transfer my self-defined GRU cell to the tutorial MNIST Classification using Neural ODEs, I don't know where to start. I would really appreciate it if anyone could help me. Thanks!

disadone · Sep 28 '24

The question does not make much sense. Having a hidden state which is carried over to the next call makes the equation not an ODE and thus not convergent. If you do what you do here, where you init the hidden state on each call, this model is equivalent to just calling the NN that is supposed to be recurrent, so you might as well call that NN directly. So I don't quite get what you're trying to do?
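
For example (a minimal sketch with hypothetical names gru, x, ps, st for a cell instance, an input batch, and its setup): since the FastGRUCell{False} method above rebuilds the hidden state via init_rnn_hidden_state on every call, two independent calls with the same input give the same output, so nothing is carried between calls:

(y1, _), _ = gru(x, ps, st)   # hidden state re-initialized internally
(y2, _), _ = gru(x, ps, st)   # same computation again
y1 == y2                      # true with the default zero initial state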

ChrisRackauckas · Oct 06 '24

Sorry for the confusion. I would like to train a sequence-to-sequence model where the RNN first derives a series of values, which are then fed into a sequence-to-sequence neural ODE as the inhomogeneous input of the equation. The weights of the RNN and the parameters of the neural ODE are trained together.

Maybe the question can be simplified to: "How can I train a sequence-to-sequence neural ODE with a series of inputs?"
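
To make this concrete, here is a rough, untested sketch of what I have in mind (the names and packages are just placeholders: OrdinaryDiffEq.jl for the solver, DataInterpolations.jl to turn the discrete RNN outputs into a continuous-time forcing, and a plain Dense layer standing in for the neural-ODE right-hand side; in the real problem the RNN weights and the ODE parameters would be trained by the same loss):

using OrdinaryDiffEq, DataInterpolations, Lux, Random

rng = Xoshiro(0)
vector_field = Dense(2 => 2, tanh)                      # stand-in for the neural-ODE right-hand side
ps, st = Lux.setup(rng, vector_field)

ts = range(0.0f0, 1.0f0; length=50)
rnn_values = rand(Float32, 50)                          # stand-in for the RNN's output sequence
forcing = LinearInterpolation(rnn_values, collect(ts))  # continuous-time view of that sequence

function rhs(u, p, t)
    du, _ = vector_field(u, p, st)
    return du .+ forcing(t)                             # RNN output enters as the inhomogeneous term
end

prob = ODEProblem(rhs, zeros(Float32, 2), (0.0f0, 1.0f0), ps)
sol = solve(prob, Tsit5(); saveat=ts)                   # a sequence out, to be compared with targets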

disadone · Oct 06 '24

Maybe @avik-pal has an example

ChrisRackauckas · Oct 12 '24

> The weights of the RNN and the parameters of the neural ODE are trained together.

Do you mean the RNN weights and the neural network weights are shared?

avik-pal · Oct 13 '24

> > The weights of the RNN and the parameters of the neural ODE are trained together.
>
> Do you mean the RNN weights and the neural network weights are shared?

No, I mean that the output of the RNN would be the input of the neural ODE at each time point.

disadone · Oct 17 '24

But without state? Then it's not an RNN?

ChrisRackauckas · Oct 21 '24