DiffEqFlux.jl
How can I integrate an RNN into an ODEProblem in Lux.jl?
Hi!
I'm wondering how an RNN can be mixed into an ODEProblem.
Back in the Flux days it seems a Recur layer had to be created, but there is already a Recurrence layer in Lux.jl:
Training of UDEs with recurrent networks
How can Lux.jl do the job now?
I defined my own GRU cell, and it works well when combined with the beginner tutorial Training a Simple LSTM:
using ConcreteStructs: @concrete
using Lux
using Static
using Random
const IntegerType = Union{Integer, Static.StaticInteger}
const BoolType = Union{StaticBool, Bool, Val{true}, Val{false}}
@concrete struct FastGRUCell <: Lux.AbstractRecurrentCell
    train_state <: StaticBool
    in_dims <: IntegerType
    out_dims <: IntegerType
    init_bias
    init_weight
    init_state
    dynamics_nonlinearity
    gating_nonlinearity
    α <: AbstractFloat
    layernormQ::StaticBool
end
function FastGRUCell(
        (in_dims, out_dims)::Pair{<:Lux.IntegerType, <:Lux.IntegerType},
        Δt::T, τ::T, layernormQ::BoolType;
        train_state::BoolType=False(),
        init_weight=Lux.glorot_normal,
        init_bias=Lux.zeros32,
        init_state=Lux.zeros32,
        dynamics_nonlinearity=Lux.sigmoid_fast,
        gating_nonlinearity=Lux.tanh_fast) where {T<:AbstractFloat}
    init_weight = ntuple(Returns(init_weight), 3)
    init_bias = ntuple(Returns(init_bias), 3)
    α = Δt / τ
    return FastGRUCell(
        static(train_state),
        in_dims, out_dims, init_bias, init_weight, init_state,
        dynamics_nonlinearity, gating_nonlinearity, α, static(layernormQ))
end
function Lux.initialparameters(rng::AbstractRNG, gru::FastGRUCell)
    # hidden to hidden
    Wz, Wr, Wh = (Lux.init_rnn_weight(
        rng, init_weight, gru.out_dims, (gru.out_dims, gru.out_dims))
                  for init_weight in gru.init_weight)
    # input to hidden
    Uz, Ur, Uh = (Lux.init_rnn_weight(
        rng, init_weight, gru.out_dims, (gru.out_dims, gru.in_dims))
                  for init_weight in gru.init_weight)
    ps = (; Wz, Wr, Wh, Uz, Ur, Uh)
    biasz, biasr, biash = (Lux.init_rnn_weight(rng, init_bias, gru.out_dims, gru.out_dims)
                           for init_bias in gru.init_bias)
    ps = merge(ps, (; biasz, biasr, biash))
    Lux.has_train_state(gru) &&
        (ps = merge(ps, (hidden_state=gru.init_state(rng, gru.out_dims),)))
    return ps
end
Lux.initialstates(rng::AbstractRNG, ::FastGRUCell) = (rng=Lux.Utils.sample_replicate(rng),)

function (gru::FastGRUCell{True})(x::AbstractMatrix, ps, st::NamedTuple)
    hidden_state = Lux.init_trainable_rnn_hidden_state(ps.hidden_state, x)
    return gru((x, (hidden_state,)), ps, st)
end

function (gru::FastGRUCell{False})(x::AbstractMatrix, ps, st::NamedTuple)
    rng = Lux.replicate(st.rng)
    st = merge(st, (; rng))
    hidden_state = Lux.init_rnn_hidden_state(rng, gru, x)
    return gru((x, (hidden_state,)), ps, st)
end
const _FastGRUCellInputType = Tuple{<:AbstractMatrix, Tuple{<:AbstractMatrix}}

function (m::FastGRUCell)((x, (h,))::_FastGRUCellInputType, ps, st::NamedTuple)
    # pre-activations: hidden-to-hidden terms carry the bias, input-to-hidden terms do not
    Wzh = fused_dense_bias_activation(identity, ps.Wz, h, ps.biasz)
    Wrh = fused_dense_bias_activation(identity, ps.Wr, h, ps.biasr)
    Uzx = fused_dense_bias_activation(identity, ps.Uz, x, nothing)
    Urx = fused_dense_bias_activation(identity, ps.Ur, x, nothing)
    # update gate z and reset gate r (optionally layer-normalizing the recurrent part)
    z = dynamic(m.layernormQ) ?
        m.gating_nonlinearity.(layernorm(Wzh, nothing, nothing) .+ Uzx) :
        m.gating_nonlinearity.(Wzh .+ Uzx)
    r = dynamic(m.layernormQ) ?
        m.gating_nonlinearity.(layernorm(Wrh, nothing, nothing) .+ Urx) :
        m.gating_nonlinearity.(Wrh .+ Urx)
    # candidate hidden state
    Whh = fused_dense_bias_activation(identity, ps.Wh, h .* r, ps.biash)
    Uhh = fused_dense_bias_activation(identity, ps.Uh, x, nothing)
    h̃ = dynamic(m.layernormQ) ?
        m.dynamics_nonlinearity.(layernorm(Whh, nothing, nothing) .+ Uhh) :
        m.dynamics_nonlinearity.(Whh .+ Uhh)
    # leaky state update with effective step size α = Δt/τ
    h′ = @. (1 - m.α * z) * h + m.α * z * h̃
    return (h′, (h′,)), st
end
# --------------------------------------------------------------------------------------------------
# adapted from https://lux.csail.mit.edu/stable/tutorials/beginner/3_SimpleRNN#Creating-a-Classifier
using Lux, JLD2, MLUtils, Optimisers, Zygote, Printf, Random, Statistics
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [reshape(d[1][:, 1:sequence_length], :, sequence_length, 1)
                         for d in data[1:(dataset_size ÷ 2)]]
    anticlockwise_spirals = [reshape(
                                 d[1][:, (sequence_length + 1):end], :, sequence_length, 1)
                             for d in data[((dataset_size ÷ 2) + 1):end]]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(collect.((x_train, y_train)); batchsize=128, shuffle=true),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false))
end
struct SpiralClassifier{L, C} <: Lux.AbstractLuxContainerLayer{(:fastgru_cell, :classifier)}
    fastgru_cell::L
    classifier::C
end

function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        FastGRUCell(in_dims => hidden_dims, 0.01f0, 1.0f0, true),
        Dense(hidden_dims => out_dims, sigmoid))
end
function (s::SpiralClassifier)(
        x::AbstractArray{T, 3}, ps::NamedTuple, st::NamedTuple) where {T}
    # iterate over the time dimension, carrying the hidden state forward
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_fastgru = s.fastgru_cell(x_init, ps.fastgru_cell, st.fastgru_cell)
    for x in x_rest
        (y, carry), st_fastgru = s.fastgru_cell((x, carry), ps.fastgru_cell, st_fastgru)
    end
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    st = merge(st, (classifier=st_classifier, fastgru_cell=st_fastgru))
    return vec(y), st
end
# ----- loss
const lossfn = BinaryCrossEntropyLoss()
function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
# ----- training
function main(model_type)
    dev = cpu_device()
    # Get the dataloaders
    train_loader, val_loader = get_dataloaders() .|> dev
    # Create the model
    model = model_type(2, 8, 1)
    rng = Xoshiro(0)
    ps, st = Lux.setup(rng, model) |> dev
    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    for epoch in 1:25
        # Train the model
        for (x, y) in train_loader
            # x: (2, 50, 128) = (features, time, batch), y: (128,)
            (_, loss, _, train_state) = Training.single_train_step!(
                AutoZygote(), lossfn, (x, y), train_state)
            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
        end
        # Validate the model
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model(x, train_state.parameters, st_)
            loss = lossfn(ŷ, y)
            acc = accuracy(ŷ, y)
            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
        end
    end
    return (train_state.parameters, train_state.states) |> cpu_device()
end
ps_trained, st_trained = main(SpiralClassifier)
When I try to transfer my self-defined GRU cell to the tutorial MNIST Classification using Neural ODEs, I don't know how to start. I'd really appreciate it if anyone could help me. Thanks!
The question does not make much sense as stated. Having a hidden state that is carried over to the next call makes the equation not an ODE and thus not convergent. If you do what you do here, where the hidden state is initialized on each call, the model is equivalent to just calling the NN that is supposed to be recurrent, so you might as well call that NN directly. So I don't quite get what you're trying to do?
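To illustrate the point about re-initialized state (a minimal sketch using Lux's built-in GRUCell, not part of the original reply): when the hidden state is freshly initialized on every call, the cell has no memory across calls and reduces to a plain feedforward map of its input.

using Lux, Random

cell = GRUCell(2 => 4)
rng = Xoshiro(0)
ps, st = Lux.setup(rng, cell)
x = rand(rng, Float32, 2, 1)

# Each call starts from a freshly initialized (zero) hidden state,
# so the output depends only on x, never on earlier calls.
(y1, _), _ = cell(x, ps, st)
(y2, _), _ = cell(x, ps, st)
y1 == y2   # true: no state is carried between calls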
Sorry for the confusion. I would like to train a sequence-to-sequence model where the RNN first produces a series of values, which are then fed into a sequence-to-sequence neural ODE as its inhomogeneous (forcing) input. The RNN weights and the neural ODE parameters are trained together.
Maybe the question can be simplified to: "How can I train a sequence-to-sequence neural ODE with a series of inputs?"
Maybe @avik-pal has an example
"The RNN weights and the neural ODE parameters are trained together."
Do you mean the RNN weights and the neural network weights are shared?
No, I mean the output of the RNN would be the input of the neural ODE at each time point.
But without state? Then it's not an RNN?
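A minimal, untested sketch of one way to set up a sequence-to-sequence neural ODE driven by a series of inputs (the network sizes, the piecewise-constant input lookup, and the quadratic loss are illustrative assumptions, not something proposed in this thread): turn the input sequence into a function of time, concatenate it with the state inside the ODE right-hand side, and save the solution at the same time points to get a sequence back. If the input series were itself produced by an RNN, its values would additionally have to enter the ODE parameters so that the adjoint returns their sensitivities and the RNN weights can be trained jointly.

using Lux, OrdinaryDiffEq, SciMLSensitivity, ComponentArrays, Zygote, Random

rng = Xoshiro(0)

# the ODE right-hand side sees [state; input(t)] and returns du/dt
state_dims, input_dims = 4, 2
node = Chain(Dense(state_dims + input_dims => 16, tanh), Dense(16 => state_dims))
ps, st = Lux.setup(rng, node)
ps = ComponentArray(ps)

tsteps  = range(0.0f0, 1.0f0; length=50)
useq    = rand(rng, Float32, input_dims, length(tsteps))   # the given input series (e.g. RNN outputs)
ytarget = rand(rng, Float32, state_dims, length(tsteps))   # placeholder target sequence

# piecewise-constant lookup of the input series; a smooth interpolation
# (e.g. DataInterpolations.jl) would be preferable in practice
input_at(t) = useq[:, clamp(searchsortedlast(tsteps, t), 1, length(tsteps))]

rhs(u, p, t) = first(node(vcat(u, input_at(t)), p, st))

function predict(p)
    prob = ODEProblem(rhs, zeros(Float32, state_dims), (tsteps[begin], tsteps[end]), p)
    return Array(solve(prob, Tsit5(); saveat=tsteps,
        sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP())))
end

loss(p) = sum(abs2, predict(p) .- ytarget)
grads = Zygote.gradient(loss, ps)   # gradients with respect to the neural ODE parameters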