LSTM Timeseries prediction example

Open NicoZweifel opened this issue 11 months ago • 9 comments

Checklist

  • [x] Confirmed that the run-checks all script has been executed.
  • [x] Made sure the book is up to date with changes in this PR.

Changes

  • Adds a timeseries forecasting example that uses the LSTM introduced in https://github.com/tracel-ai/burn/pull/370, backed by a PartialDataset over a Huggingface dataset.
  • The dataset is limited to 10,000 entries at the moment (see the sketch below); training on the full dataset still seems buggy. I am not sure whether the normalization is off or whether there is a memory limitation or a bug in the SqliteDataset.

I have not narrowed it down yet, since my other Burn project uses custom datasets (InMemory, with data from alphavantage) and they work fine. I may need to spend more time on it to figure it out, but since it doesn't block my other goals and the example works with 10,000 entries, I thought I could publish this as a draft for now.
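
For context, the limiting looks roughly like this (a minimal sketch; the item type and dataset name are placeholders, not the actual ones from this PR):

use burn::data::dataset::source::huggingface::HuggingfaceDatasetLoader;
use burn::data::dataset::transform::PartialDataset;
use burn::data::dataset::SqliteDataset;

// Placeholder item type for one row of the timeseries.
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
struct SeriesItem {
    value: f32,
}

fn load_limited() -> PartialDataset<SqliteDataset<SeriesItem>, SeriesItem> {
    // Download the dataset from Huggingface into a local SqliteDataset.
    let full: SqliteDataset<SeriesItem> =
        HuggingfaceDatasetLoader::new("some/timeseries-dataset")
            .dataset("train")
            .expect("failed to load the dataset");
    // Keep only the first 10,000 entries until the full-dataset issue is resolved.
    PartialDataset::new(full, 0, 10_000)
}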

Testing

cargo run --example lstm --features tch-cpu

NicoZweifel avatar Mar 26 '24 13:03 NicoZweifel

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 86.31%. Comparing base (7705fd9) to head (9dedb3f). Report is 59 commits behind head on main.

❗ Current head 9dedb3f differs from pull request most recent head 8a8b57f

Please upload reports for the commit 8a8b57f to get more accurate results.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1532      +/-   ##
==========================================
- Coverage   86.38%   86.31%   -0.07%     
==========================================
  Files         693      683      -10     
  Lines       80473    78091    -2382     
==========================================
- Hits        69519    67408    -2111     
+ Misses      10954    10683     -271     

☂️ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

codecov[bot] avatar Mar 26 '24 15:03 codecov[bot]

@nathanielsimard Hello, I have been following the LSTM implementation in Burn with great interest, and I still think the LSTM is buggy right now. If a linear layer is added after the LSTM, the parameters of the LSTM and of all layers before it are not updated during training. I've been stuck on this problem for a long time.

The LSTM example in this PR further confirms that there is a problem. I added some code in training.rs to check the model's parameters before and after training:

use burn::record::{FullPrecisionSettings, PrettyJsonFileRecorder};

// Snapshot each sub-module's parameters to JSON before training...
let pjr = PrettyJsonFileRecorder::<FullPrecisionSettings>::new();
model.input_layer.clone().save_file("./input-before.json", &pjr).unwrap();
model.lstm.clone().save_file("./lstm-before.json", &pjr).unwrap();
model.output_layer.clone().save_file("./output-before.json", &pjr).unwrap();

// ...

// ...and again after training, so the two snapshots can be diffed.
model_trained.input_layer.clone().save_file("./input-after.json", &pjr).unwrap();
model_trained.lstm.clone().save_file("./lstm-after.json", &pjr).unwrap();
model_trained.output_layer.clone().save_file("./output-after.json", &pjr).unwrap();

After training, only the parameters of the output_layer had changed. That said, for the dataset in this example, a single linear layer might be enough to overfit.

wcshds avatar Mar 28 '24 14:03 wcshds

> I have been following the LSTM implementation in Burn with great interest, and I still think the LSTM is buggy right now.

I was hoping I could spark the development of the LSTM implementation a bit with an example. I would love to use Burn for this purpose as well.

> After training, only the parameters of the output_layer had changed. That said, for the dataset in this example, a single linear layer might be enough to overfit.

Happy to incorporate your suggestions! Feel free to open a PR with changes against this branch.

NicoZweifel avatar Mar 28 '24 15:03 NicoZweifel

@wcshds @NicoZweifel I have identified the issue, and we already have a planned fix; however, we will now prioritize it, as it directly affects a real-world use case. The problem lies in the autodiff graph, which is always attached to a tensor. When two tensors with different graphs interact, we merge the graphs. However, this process assumes that all nodes in the graph will eventually interact, which is not the case for LSTM. For instance, you may only use the hidden_state, while the graph is actually held by the gate_state, which explains the problem with the current LSTM implementation.

We already want to implement a client/server architecture in burn-autodiff to avoid graph merging and locking, which will also fix this problem.
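
To illustrate, the problematic pattern looks like this (a minimal sketch with illustrative names, assuming the Lstm::forward signature at the time, which returned both states):

use burn::module::Module;
use burn::nn::{Linear, Lstm};
use burn::tensor::{backend::Backend, Tensor};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    lstm: Lstm<B>,
    output_layer: Linear<B>,
}

impl<B: Backend> Model<B> {
    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        // The LSTM returns both states; only the hidden state is used downstream.
        let (_gate_state, hidden_state) = self.lstm.forward(input, None);
        // The discarded state is the one holding the autodiff graph, so
        // gradients never reach the LSTM's parameters or any layer before it.
        self.output_layer.forward(hidden_state)
    }
}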

nathanielsimard avatar Mar 28 '24 17:03 nathanielsimard

@nathanielsimard do we have a separate ticket for the planned fix? It would be good to track it and link it here.

antimora avatar Mar 28 '24 18:03 antimora

@nathanielsimard Kinda off topic, but it would be cool to have a generic TimeSeriesDataset that supports windowing, similar to what other libraries offer. If this is something that is desired, I could look into it in a separate issue/PR.
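
Roughly what I have in mind (a sketch only; WindowDataset is a hypothetical name written against Burn's Dataset trait, not an existing API):

use std::marker::PhantomData;

use burn::data::dataset::Dataset;

// Hypothetical windowing wrapper: item `i` is the window of `size`
// consecutive items starting at index `i` (stride 1).
pub struct WindowDataset<D, I> {
    dataset: D,
    size: usize,
    marker: PhantomData<I>,
}

impl<D: Dataset<I>, I> WindowDataset<D, I> {
    pub fn new(dataset: D, size: usize) -> Self {
        assert!(size >= 1, "window size must be at least 1");
        Self { dataset, size, marker: PhantomData }
    }
}

impl<D: Dataset<I>, I: Send + Sync> Dataset<Vec<I>> for WindowDataset<D, I> {
    fn get(&self, index: usize) -> Option<Vec<I>> {
        // A window is valid only if every item it covers exists.
        if index + self.size > self.dataset.len() {
            return None;
        }
        (index..index + self.size).map(|i| self.dataset.get(i)).collect()
    }

    fn len(&self) -> usize {
        // One window per valid starting index.
        self.dataset.len().saturating_sub(self.size - 1)
    }
}

A batcher could then map each Vec<I> window into input/target tensors for training.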

NicoZweifel avatar Mar 28 '24 18:03 NicoZweifel

@NicoZweifel, that would be a great addition. You can file an issue for this and we can assign it to you.

antimora avatar Mar 28 '24 18:03 antimora

Thanks, I created a separate issue to discuss the details 👍

NicoZweifel avatar Mar 28 '24 21:03 NicoZweifel

This PR has been marked as stale because it has not been updated for over a month.

github-actions[bot] avatar May 19 '24 12:05 github-actions[bot]

@NicoZweifel Hey 👋 I'm going through open issues/PRs right now, and it looks like there hasn't been much activity here for a while.

I'll close the draft PR, but if you want to take it up eventually and need a review, feel free to reopen and ping us 🙏

laggui avatar Sep 13 '24 17:09 laggui

@laggui Thanks, will do 👍. I've made some progress on my local fork in the meantime, but I need to finish a separate feature PR and then update this one before reopening, since it also contains some changes to window.rs (split into an iterator + Dataset because of the lifetime).

I'd love to reopen and finish this eventually.

NicoZweifel avatar Oct 06 '24 20:10 NicoZweifel