LSTM Timeseries prediction example
Checklist
- [x] Confirmed that the `run-checks all` script has been executed.
- [x] Made sure the book is up to date with changes in this PR.
Changes
- Adds a timeseries forecasting example using the LSTM module that was added in https://github.com/tracel-ai/burn/pull/370, using a partial dataset from Huggingface.
- The dataset is limited to 10000 entries at the moment. Training on the full dataset still seems to be buggy. I am not sure whether the normalization is off or whether there is a memory limitation or a bug in the `SqliteDataset`. I have not narrowed it down yet, as I am using custom datasets in my other burn project and they work fine (InMemory with data from alphavantage). I might need to spend some more time on it to figure it out, but since it doesn't block my other goals and the example seems to work with 10000 entries, I thought I could publish this as a draft for now.
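For reference, here is a minimal sketch of one common normalization scheme for a univariate series (min-max scaling to [0, 1]); this is an assumption about the kind of preprocessing involved, not the code in this PR:

```rust
/// Min-max normalization of a univariate series to [0, 1] (a hypothetical sketch).
/// The returned (min, max) pair must be reused to de-normalize predictions;
/// mixing up train/validation statistics is a typical source of the symptoms
/// described above.
fn normalize(series: &[f32]) -> (Vec<f32>, f32, f32) {
    let min = series.iter().cloned().fold(f32::INFINITY, f32::min);
    let max = series.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    // Guard against a constant series to avoid division by zero.
    let range = (max - min).max(f32::EPSILON);
    (series.iter().map(|x| (x - min) / range).collect(), min, max)
}
```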
Testing
```sh
cargo run --example lstm --features tch-cpu
```
Codecov Report
All modified and coverable lines are covered by tests :white_check_mark:
Project coverage is 86.31%. Comparing base (7705fd9) to head (9dedb3f). Report is 59 commits behind head on main.
:exclamation: Current head 9dedb3f differs from pull request most recent head 8a8b57f. Please upload reports for the commit 8a8b57f to get more accurate results.
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##             main    #1532      +/-   ##
==========================================
- Coverage   86.38%   86.31%   -0.07%
==========================================
  Files         693      683      -10
  Lines       80473    78091    -2382
==========================================
- Hits        69519    67408    -2111
+ Misses      10954    10683     -271
```
:umbrella: View full report in Codecov by Sentry.
@nathanielsimard Hello, I have always been interested in the implementation of LSTM in burn. I still think the LSTM is buggy right now: if a linear layer is added after the LSTM, the parameters of the LSTM and all layers before it are not updated during training. I've been stuck on this problem for a long time.
The example of using the LSTM in this PR further confirms that it has problems. I added some code in `training.rs` to check the parameters of the model before and after training:
```rust
use burn::record::{FullPrecisionSettings, PrettyJsonFileRecorder};

// Dump the parameters of each sub-module to JSON before training...
let pjr = PrettyJsonFileRecorder::<FullPrecisionSettings>::new();
model.input_layer.clone().save_file("./input-before.json", &pjr).unwrap();
model.lstm.clone().save_file("./lstm-before.json", &pjr).unwrap();
model.output_layer.clone().save_file("./output-before.json", &pjr).unwrap();
// ......
// ...and again after training, to compare which parameters actually changed.
model_trained.input_layer.clone().save_file("./input-after.json", &pjr).unwrap();
model_trained.lstm.clone().save_file("./lstm-after.json", &pjr).unwrap();
model_trained.output_layer.clone().save_file("./output-after.json", &pjr).unwrap();
```
After training, only the parameters of the `output_layer` changed. That said, for the dataset in this example, a single linear layer might be enough to overfit anyway.
> @nathanielsimard Hello, I have always been interested in the implementation of LSTM in burn.

I was hoping I could spark the development of the LSTM implementation a bit with an example. I would love to use Burn for this purpose as well.

> After training, only the parameters of the `output_layer` changed. That said, for the dataset in this example, a single linear layer might be enough to overfit anyway.

Happy to incorporate your suggestions! Feel free to create a PR that makes changes to this branch.
@wcshds @NicoZweifel I have identified the issue and we already have a planned fix; we will prioritize it, since it directly affects a real-world use case. The problem lies in the autodiff graph, which is always attached to a tensor. When two tensors with different graphs interact, we merge the graphs. However, this process assumes that all nodes in the graph will eventually interact, which is not the case for the LSTM. For instance, you may only use the `hidden_state`, but the graph is actually held by the `gate_state`, which explains the problem when working with the current LSTM implementation.
We already want to implement a client/server architecture in `burn-autodiff` to avoid graph merging and locking, and to fix this problem.
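To make the failure mode concrete, here is a hedged sketch of the usage pattern being described. The struct and sizes are illustrative rather than the example's exact code, and the `Lstm::forward` signature (returning `(cell_state, hidden_state)`) and module paths are assumptions based on the burn version this PR targets:

```rust
use burn::module::Module;
use burn::nn::{Linear, Lstm};
use burn::tensor::{backend::Backend, Tensor};

// Hypothetical model: an LSTM followed by a linear output head.
#[derive(Module, Debug)]
struct Net<B: Backend> {
    lstm: Lstm<B>,
    output_layer: Linear<B>,
}

impl<B: Backend> Net<B> {
    fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {
        // Only the hidden state is consumed downstream. Per the explanation above,
        // the backward graph that records the gate computations is held by the
        // gate_state rather than by hidden_state, so it never merges with the loss
        // graph and the LSTM parameters (and any layers before it) get no gradients.
        let (_cell_state, hidden_state) = self.lstm.forward(input, None);
        let [batch, seq_len, hidden] = hidden_state.dims();
        // Use the last time step as features for the output head.
        let last = hidden_state
            .slice([0..batch, (seq_len - 1)..seq_len, 0..hidden])
            .reshape([batch, hidden]);
        self.output_layer.forward(last)
    }
}
```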
@nathanielsimard Do we have a separate ticket for the planned fix? It would be good to track it and link it here.
@nathanielsimard Kinda off topic, but it would be cool to have a generic `TimeSeriesDataset` that supports windowing, similar to what other libraries have. If this is something that is desired, I could try to look into it in a separate issue/PR.
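As a rough illustration of what such a windowing wrapper could look like on top of burn's `Dataset` trait; the `WindowedDataset` name and design here are hypothetical, not an existing burn API:

```rust
use burn::data::dataset::Dataset;

/// Hypothetical wrapper that yields fixed-size windows of consecutive items
/// from any underlying dataset, e.g. for timeseries forecasting.
/// Assumes `window_size >= 1`.
pub struct WindowedDataset<D> {
    inner: D,
    window_size: usize,
}

impl<D> WindowedDataset<D> {
    pub fn new(inner: D, window_size: usize) -> Self {
        Self { inner, window_size }
    }
}

impl<I, D: Dataset<I>> Dataset<Vec<I>> for WindowedDataset<D> {
    fn get(&self, index: usize) -> Option<Vec<I>> {
        // The window starting at `index` covers items [index, index + window_size).
        (index..index + self.window_size)
            .map(|i| self.inner.get(i))
            .collect()
    }

    fn len(&self) -> usize {
        self.inner.len().saturating_sub(self.window_size - 1)
    }
}
```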
@NicoZweifel, that would be a great addition. You can file an issue for this and we can assign it to you.
Thanks, I created a separate issue to discuss the details 👍
This PR has been marked as stale because it has not been updated for over a month.
@NicoZweifel Hey 👋 I'm going through open issues/PRs right now, and it looks like there hasn't been a lot of activity here for a while.
I'll close the draft PR, but if you want to take it up eventually and need a review, feel free to reopen and ping us 🙏
@laggui Thanks, will do 👍. I've made some progress on my local fork in the meantime, but I need to finish a separate feature PR, as well as update/maintain this one before re-opening, since it also contains some changes to `window.rs` (split into an iterator + Dataset because of the lifetime).
I'd love to reopen and finish this eventually.