Fix issues / Enhance code in LSTM AE
- Fix deprecation issue in MSELoss (reduction is now the proper keyword; 'sum' is the mode it used to operate in, so I preserved it here)
- Fix issue where squeezing x in the LSTM decoder could produce a 1D tensor, causing torch.mm to raise an error
- Add gradient clipping
- Add a device argument to quick_train (also applies to train_model and get_encodings)
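For reviewers, here is a rough sketch of the pattern these changes follow. The training step, the clip_value parameter, and the decoder snippet are illustrative assumptions rather than the repo's actual internals:

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MSELoss fix: 'reduction' replaces the deprecated size_average/reduce flags;
# reduction="sum" preserves the old summed-loss behaviour.
criterion = nn.MSELoss(reduction="sum")


def train_step(model, x, optimizer, device, clip_value=1.0):
    """Illustrative training step: device handling plus gradient clipping."""
    model = model.to(device)
    x = x.to(device)

    optimizer.zero_grad()
    loss = criterion(model(x), x)
    loss.backward()

    # Gradient clipping guards against exploding gradients in the LSTM
    nn.utils.clip_grad_norm_(model.parameters(), clip_value)
    optimizer.step()
    return loss.item()


# Decoder fix (roughly): squeeze a specific dim instead of squeeze(), so a
# length-1 sequence is not collapsed to 1D -- torch.mm needs 2D operands.
#     x = x.squeeze(0)   instead of   x = x.squeeze()
```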
Thanks for opening this. What testing did you do?
Good question.
Testing was mostly limited to my use case, as I did not write unit tests for this.
- Tested the LSTM AE via quick_train on a custom dataset containing variable-length sequences (including length 1, which was problematic in the previous version of the code).
- Made sure device selection works via the new keyword argument (passing a torch.device directly).
- Made sure the "deprecated" loss was exactly equal to the "new" loss (MSELoss with reduction='sum') during training; see the check sketched below.
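The loss comparison was along these lines (a minimal standalone check, not the exact test I ran; it assumes the deprecated size_average flag is still accepted on your PyTorch version, where it only emits a deprecation warning):

```python
import torch
import torch.nn as nn

x = torch.randn(8, 16)
y = torch.randn(8, 16)

# Deprecated-style summed loss (triggers a deprecation warning)
old_loss = nn.MSELoss(size_average=False)(x, y)
# Current equivalent using the reduction keyword
new_loss = nn.MSELoss(reduction="sum")(x, y)

assert torch.allclose(old_loss, new_loss)
```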
I'm not gonna lie: I know the delta in additions and deletions looks big, but most of it comes from black's formatting, which I forgot to disable before saving, as well as from premade functions from PyTorch. I did not do extensive testing beyond my use case.
If you spot anything, let me know and I'll commit a fix.