Inconsistent LSTM results in time-series forecasting between Flux.jl and Lux.jl

Open · chooron opened this issue 10 months ago · 1 comment

Hello, I first built an LSTM using Flux.jl, and its prediction results were quite satisfactory. The Flux.jl model and training code are as follows:

  1. Build the LSTM model
using Flux

# model: a single LSTM layer followed by a Dense readout
mutable struct LSTMRegressor
    lstm_layer::Flux.Recur
    dense_layer::Dense
end

function LSTMRegressor(; input_dim::Int, hidden_size::Int, output_dim::Int)
    lstm = LSTM(input_dim => hidden_size)
    dense = Dense(hidden_size => output_dim)
    return LSTMRegressor(lstm, dense)
end

# forward pass: feed every timestep (dim 1) through the LSTM, then apply
# the Dense layer to the final hidden state
function (m::LSTMRegressor)(x)
    h_end = []
    Flux.reset!(m)
    for i in 1:size(x, 1)  # note: `for i in size(x, 1)` would visit only the last step
        h_end = m.lstm_layer(x[i, :, :])
    end
    return m.dense_layer(h_end)
end

Flux.@functor LSTMRegressor
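
For context, this forward pass slices the input along the first dimension, so the model expects arrays of shape (timesteps, features, batch). A minimal sketch of a forward call, with purely hypothetical dimensions:

model = LSTMRegressor(input_dim=4, hidden_size=16, output_dim=1)
x = rand(Float32, 10, 4, 32)  # (timesteps, features, batch)
y = model(x)                  # (output_dim, batch), from the final hidden state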
  2. Model training
using Statistics: mean

function train(loaders; model, epochs)
    model = model |> gpu
    opt_state = Flux.setup(Adam(1e-3), model)
    train_loader, val_loader = loaders
    for epoch in 1:epochs
        # training pass
        train_loss = []
        Flux.trainmode!(model)
        for (x, y) in train_loader
            x = x |> gpu
            y = y |> gpu
            val, grads = Flux.withgradient(model) do m
                y_hat = m(x)
                Flux.mse(y_hat, y)
            end
            push!(train_loss, val)
            Flux.update!(opt_state, model, grads[1])
        end
        Flux.testmode!(model)
        # validation pass
        val_loss = []
        for (x, y) in val_loader
            x = x |> gpu
            y = y |> gpu
            push!(val_loss, Flux.mse(model(x), y))
        end
        # `mean` already divides by the length; the original divided a second time
        println("Epoch [$epoch]: mean train loss $(mean(train_loss)), mean val loss $(mean(val_loss))")
    end
    return model
end
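
A hedged usage sketch for this training loop, assuming Flux.DataLoader and the (timesteps, features, batch) layout above; the array sizes are hypothetical:

using Flux: DataLoader

# dummy data purely for illustration
x_train = rand(Float32, 10, 4, 256); y_train = rand(Float32, 1, 256)
x_val = rand(Float32, 10, 4, 64); y_val = rand(Float32, 1, 64)

# DataLoader batches along the last dimension by default
train_loader = DataLoader((x_train, y_train); batchsize=32, shuffle=true)
val_loader = DataLoader((x_val, y_val); batchsize=32)

model = LSTMRegressor(input_dim=4, hidden_size=16, output_dim=1)
model = train((train_loader, val_loader); model=model, epochs=10)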

Then I built what should be the same LSTM model with Lux.jl, but its prediction results were far worse. This is my Lux.jl code:

  1. Build the LSTM model
using Lux

# container layer holding the LSTM cell and the Dense regressor
struct LSTMRegressor{L,C} <:
       Lux.AbstractExplicitContainerLayer{(:lstm_cell, :regressor)}
    lstm_cell::L
    regressor::C
end

# constructor
function LSTMRegressor(; in_dims, hidden_dims, out_dims)
    return LSTMRegressor(LSTMCell(in_dims => hidden_dims),
        Dense(hidden_dims => out_dims))
end

# forward pass: unroll the LSTM cell along dims=2, then apply the
# regressor to the final hidden state
function (m::LSTMRegressor)(x::AbstractArray{T,3}, ps::NamedTuple, st::NamedTuple) where {T}
    x_init, x_rest = Iterators.peel(eachslice(x; dims=2))
    (y, carry), st_lstm = m.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    for x in x_rest
        (y, carry), st_lstm = m.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    y, st_regressor = m.regressor(y, ps.regressor, st.regressor)
    st = merge(st, (regressor=st_regressor, lstm_cell=st_lstm))
    return y, st
end
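
Note that this forward pass unrolls along dims=2, so the Lux model expects (features, timesteps, batch), whereas the Flux version above slices along dim 1. A minimal sketch of parameter setup and a forward call, with hypothetical dimensions:

using Random

rng = Random.default_rng()
model = LSTMRegressor(in_dims=4, hidden_dims=16, out_dims=1)
ps, st = Lux.setup(rng, model)

x = rand(Float32, 4, 10, 32)  # (features, timesteps, batch)
y, st = model(x, ps, st)      # y has shape (out_dims, batch)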
  2. Model training
using Optimisers, Zygote
using Statistics: mean

# note: `sum(abs2, ...)` is a sum of squared errors; the Flux version
# above uses `Flux.mse`, which takes the mean
function mse_loss(x, y, model, ps, st)
    y_hat, st = model(x, ps, st)
    return sum(abs2, y_hat .- y), st
end
function create_optimizer(ps, lr)
    opt = Optimisers.Adam(lr)  # `ADAM` is a deprecated alias in Optimisers.jl
    return Optimisers.setup(opt, ps)
end
function train(loaders; model, ps, st, epochs, device)
    ps = ps |> device
    st = st |> device
    opt_state = create_optimizer(ps, 1e-3)
    train_loader, val_loader = loaders
    for epoch in 1:epochs
        # train lstm model
        train_loss = []
        for (x, y) in train_loader
            x = x |> device
            y = y |> device
            # the loss returns a `(loss, st)` tuple, so the pullback seed
            # must be a matching 2-tuple
            (loss, st), back = pullback(p -> mse_loss(x, y, model, p, st), ps)
            gs = back((one(loss), nothing))[1]
            opt_state, ps = Optimisers.update(opt_state, ps, gs)
            push!(train_loss, loss)
        end
        st_ = Lux.testmode(st)
        # validate lstm model
        val_loss = []
        for (x, y) in val_loader
            x = x |> device
            y = y |> device
            loss, st_ = mse_loss(x, y, model, ps, st_)
            push!(val_loss, loss)
        end
        println("Epoch [$epoch]: mean train Loss $(mean(train_loss)/size(train_loss,1)), mean val Loss $(mean(val_loss)/size(val_loss,1))")
    end
    return (ps, st) |> cpu_device()
end
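
And a hedged sketch of calling this training function, assuming DataLoader from MLUtils.jl and the gpu_device/cpu_device helpers that Lux re-exports; all names and sizes are hypothetical:

using MLUtils: DataLoader
using Random

rng = Random.default_rng()
model = LSTMRegressor(in_dims=4, hidden_dims=16, out_dims=1)
ps, st = Lux.setup(rng, model)

# dummy data in the (features, timesteps, batch) layout
x_train = rand(Float32, 4, 10, 256); y_train = rand(Float32, 1, 256)
x_val = rand(Float32, 4, 10, 64); y_val = rand(Float32, 1, 64)
train_loader = DataLoader((x_train, y_train); batchsize=32, shuffle=true)
val_loader = DataLoader((x_val, y_val); batchsize=32)

ps, st = train((train_loader, val_loader);
               model=model, ps=ps, st=st, epochs=10, device=gpu_device())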

I would like to ask whether there is a problem in my Lux.jl code that causes this discrepancy.

chooron · Oct 18 '23 14:10