Inconsistent LSTM results in time series forecasting between Flux.jl and Lux.jl
Hello, I first built an LSTM using Flux.jl, and its predictions were quite satisfactory. The Flux.jl model and its training loop are as follows:
- Build the LSTM model

using Flux

mutable struct LSTMRegressor
    lstm_layer::Flux.Recur
    dense_layer::Dense
end

function LSTMRegressor(; input_dim::Int, hidden_size::Int, output_dim::Int)
    lstm = LSTM(input_dim => hidden_size)
    dense = Dense(hidden_size => output_dim)
    return LSTMRegressor(lstm, dense)
end

# forward pass: x has shape (timesteps, features, batch);
# feed every timestep through the LSTM, then regress on the final hidden state
function (m::LSTMRegressor)(x)
    h_end = []
    Flux.reset!(m)
    for i in 1:size(x, 1)   # loop over timesteps
        h_end = m.lstm_layer(x[i, :, :])
    end
    y = m.dense_layer(h_end)
    return y
end

Flux.@functor LSTMRegressor
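To sanity-check the shapes, I run a quick smoke test like this (a minimal sketch; the sizes are placeholders I made up, and the timesteps × features × batch layout is what the indexing above assumes):

model = LSTMRegressor(input_dim=8, hidden_size=32, output_dim=1)
x = rand(Float32, 10, 8, 16)   # 10 timesteps, 8 features, batch of 16
y = model(x)                   # size (1, 16): one prediction per batch element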
- Model training

using Statistics: mean

function train(loaders; model, epochs)
    model = model |> gpu
    opt_state = Flux.setup(Adam(1e-3), model)
    train_loader, val_loader = loaders
    for epoch in 1:epochs
        # train the LSTM model
        train_loss = []
        Flux.trainmode!(model)
        for (i, (x, y)) in enumerate(train_loader)
            x = x |> gpu
            y = y |> gpu
            val, grads = Flux.withgradient(model) do m
                y_hat = m(x)
                Flux.mse(y_hat, y)
            end
            push!(train_loss, val)
            Flux.update!(opt_state, model, grads[1])
        end
        Flux.testmode!(model)
        # validate the LSTM model
        val_loss = []
        for (i, (x, y)) in enumerate(val_loader)
            x = x |> gpu
            y = y |> gpu
            y_hat = model(x)
            loss = Flux.mse(y_hat, y)
            push!(val_loss, loss)
        end
        println("Epoch [$epoch]: mean train loss $(mean(train_loss)), mean val loss $(mean(val_loss))")
    end
    return model
end
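I build the loaders and call it roughly like this (a sketch with made-up sizes; Flux's DataLoader batches along the last dimension, so samples go on the last axis):

using Flux: DataLoader

X = rand(Float32, 10, 8, 256)   # timesteps × features × samples
Y = rand(Float32, 1, 256)       # one target per sample
train_loader = DataLoader((X[:, :, 1:200], Y[:, 1:200]); batchsize=32, shuffle=true)
val_loader   = DataLoader((X[:, :, 201:end], Y[:, 201:end]); batchsize=32)
model = train((train_loader, val_loader);
              model=LSTMRegressor(input_dim=8, hidden_size=32, output_dim=1),
              epochs=10)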
Then I used Lux.jl to build what should be the same LSTM model, but its predictions were far worse. This is my Lux.jl code:
- Build the LSTM model

using Lux

# struct definition: a container layer holding an LSTM cell and a dense regressor
struct LSTMRegressor{L,C} <:
       Lux.AbstractExplicitContainerLayer{(:lstm_cell, :regressor)}
    lstm_cell::L
    regressor::C
end

# constructor
function LSTMRegressor(; in_dims, hidden_dims, out_dims)
    return LSTMRegressor(LSTMCell(in_dims => hidden_dims),
                         Dense(hidden_dims => out_dims))
end

# forward pass: x has shape (features, timesteps, batch);
# unroll the cell over the time dimension, then regress on the final output
function (m::LSTMRegressor)(x::AbstractArray{T,3}, ps::NamedTuple, st::NamedTuple) where {T}
    x_init, x_rest = Iterators.peel(eachslice(x; dims=2))
    (y, carry), st_lstm = m.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    for x in x_rest
        (y, carry), st_lstm = m.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    y, st_regressor = m.regressor(y, ps.regressor, st.regressor)
    st = merge(st, (regressor=st_regressor, lstm_cell=st_lstm))
    return y, st
end
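Parameters and state come from Lux.setup, and a quick smoke test looks like this (again with placeholder sizes; note that time is now the second dimension):

using Lux, Random

rng = Random.default_rng()
model = LSTMRegressor(in_dims=8, hidden_dims=32, out_dims=1)
ps, st = Lux.setup(rng, model)
x = rand(Float32, 8, 10, 16)   # 8 features, 10 timesteps, batch of 16
y, st = model(x, ps, st)       # y has size (1, 16)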
- Model training

using Optimisers, Zygote
using Statistics: mean

# note: this returns the *sum* of squared errors, not the mean
function mse_loss(x, y, model, ps, st)
    y_hat, st = model(x, ps, st)
    return sum(abs2, y_hat .- y), st
end
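One thing I noticed while writing this up: Flux.mse averages over all elements, while the loss above sums them, so the two setups optimize losses on different scales. A mean-based variant matching Flux would look like this sketch (mse_loss_mean is just a name I made up):

function mse_loss_mean(x, y, model, ps, st)
    y_hat, st = model(x, ps, st)
    return sum(abs2, y_hat .- y) / length(y), st
end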
function create_optimizer(ps, lr)
    opt = Optimisers.Adam(lr)
    return Optimisers.setup(opt, ps)
end
function train(loaders; model, ps, st, epochs, device)
    ps = ps |> device
    st = st |> device
    opt_state = create_optimizer(ps, 1e-3)
    train_loader, val_loader = loaders
    for epoch in 1:epochs
        # train the LSTM model
        train_loss = []
        for (x, y) in train_loader
            x = x |> device
            y = y |> device
            (loss, st), back = pullback(p -> mse_loss(x, y, model, p, st), ps)
            gs = back((one(loss), nothing))[1]   # seed matches the (loss, st) return
            opt_state, ps = Optimisers.update(opt_state, ps, gs)
            push!(train_loss, loss)
        end
        st_ = Lux.testmode(st)
        # validate the LSTM model
        val_loss = []
        for (x, y) in val_loader
            x = x |> device
            y = y |> device
            loss, st_ = mse_loss(x, y, model, ps, st_)
            push!(val_loss, loss)
        end
        println("Epoch [$epoch]: mean train loss $(mean(train_loss)), mean val loss $(mean(val_loss))")
    end
    return (ps, st) |> cpu_device()
end
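I call it roughly like this (a sketch; model, ps, and st come from the setup snippet above, the loaders are built the same way as in the Flux version except that x is laid out features × timesteps × batch, and gpu_device/cpu_device come from Lux):

ps, st = train((train_loader, val_loader);
               model=model, ps=ps, st=st, epochs=10, device=gpu_device())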
I would like to ask: is there a problem in my Lux.jl code that would cause this discrepancy?