[BUG] How to train a TCNModel with a custom dataset?
Describe the bug
By debugging the training process for the TCNModel class, which leads to forward() and _compute_loss(), I found that _compute_loss() reshapes the data using PyTorch's squeeze operator, causing a mismatch between the target shape and the model output's shape. The offending line is here: https://github.com/unit8co/darts/blob/000d29d33f6888d079fa426731ff7ac9b1348f8d/darts/models/forecasting/pl_forecasting_module.py#L368
To Reproduce
So far I have used a TCNModel instance with batch_size=64, input_chunk_length=64, output_chunk_length=1. I called fit_from_dataset() with a custom subclass of PastCovariatesSequentialDataset that overrides the __getitem__ method so as to return my own sliding windows. If you need to see the code for the dataset class, just let me know.
Expected behavior
I would expect the model output's tensor shape to match the target tensor's shape so that the loss is computed correctly.
System (please complete the following information):
- Python version: 3.9.18
- darts version: 0.28.0
Additional context
When the output object reaches that line of code, its tensor shape is (64, 64, 1, 1) (a shape produced by the following line: https://github.com/unit8co/darts/blob/000d29d33f6888d079fa426731ff7ac9b1348f8d/darts/models/forecasting/tcn_model.py#L249), while the target's shape is (64, 1, 1). The squeeze operator turns the output shape into (64, 64, 1) instead, causing the mismatch.
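The shape mismatch can be reproduced with a minimal sketch. NumPy is used here for a self-contained example (the actual Darts code operates on PyTorch tensors, but squeeze semantics over the last axis are the same); the shapes are the ones reported above:

```python
import numpy as np

# Model output as described above: (batch, input_chunk_length, n_targets, 1)
output = np.zeros((64, 64, 1, 1))
# Target returned by the custom dataset: (batch, output_chunk_length, n_targets)
target = np.zeros((64, 1, 1))

# Squeezing the last axis mimics what the loss computation does
# before handing the tensors to the criterion
squeezed = output.squeeze(axis=-1)
print(squeezed.shape)  # (64, 64, 1) -- does not match target.shape == (64, 1, 1)
```

The second axis (64 vs. 1) is where the broadcast warning from the MSE computation comes from.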
I also used Google's Gemini to help me reason about what was happening; this is the chat, in case it is of any use: https://g.co/gemini/share/93e4bcaa9b3b
Hi @giacomoguiduzzi, could you provide a minimum reproducible example for this issue? Without that it's difficult to help.
Hi @dennisbader, I thought of opening the issue in the meantime just in case you had any idea of similar issues; I'm working on the example to provide you, I'll paste it here as soon as I have it. Thanks in advance!
Hello @dennisbader, I created a git repo as the script is pretty long, although I tried to minimise the code for the example: https://github.com/giacomoguiduzzi/tcnmodel_bug_example The script terminates correctly, but you can see PyTorch's UserWarnings about the MSE computation in the output log. Looking forward to your response.
Hi @giacomoguiduzzi, and sorry for the late response. TCNModel is a special case, where the output of the model always has input_chunk_length points (see the training dataset configuration here). The reasoning behind it is described in #1965.
The future target returned by your training dataset should include input_chunk_length - output_chunk_length points from the end of the past_target, as well as output_chunk_length points from your future target.
```python
future_target = np.concatenate(
    [past_target[-(self.input_chunk_length - self.output_chunk_length):], future_target],
    axis=0,
)
```
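With the numbers from this issue (input_chunk_length=64, output_chunk_length=1), that concatenation extends the future target to 64 points, matching the TCN output length. A minimal sketch with made-up array contents:

```python
import numpy as np

input_chunk_length = 64
output_chunk_length = 1

# Hypothetical sliding-window slices; shapes are (n_points, n_targets)
past_target = np.random.rand(input_chunk_length, 1)     # (64, 1)
future_target = np.random.rand(output_chunk_length, 1)  # (1, 1)

# Prepend the last (input_chunk_length - output_chunk_length) past points
future_target = np.concatenate(
    [past_target[-(input_chunk_length - output_chunk_length):], future_target],
    axis=0,
)
print(future_target.shape)  # (64, 1) -- now matches the TCN output length
```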
P.s. if you're already using Darts version 0.30.0: we have now added sample weights to our training datasets. This requires returning an additional value from the __getitem__ method.
You can simply add a line to the return
```python
return (
    past_target,
    covariate,
    static_covariate,
    None,  # <---- add this line for the sample weights
    future_target,
)
```
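Putting both points together, a custom __getitem__ might look like the sketch below. The class and variable names (MyDataset, covariate, static_covariate, etc.) are assumptions for illustration, not the exact Darts internals:

```python
import numpy as np

class MyDataset:
    """Hypothetical stand-in for a PastCovariatesSequentialDataset subclass."""

    def __init__(self, series, input_chunk_length=64, output_chunk_length=1):
        self.series = series  # array of shape (n_points, n_targets)
        self.input_chunk_length = input_chunk_length
        self.output_chunk_length = output_chunk_length

    def __getitem__(self, idx):
        end = idx + self.input_chunk_length
        past_target = self.series[idx:end]
        future_target = self.series[end:end + self.output_chunk_length]

        # TCNModel trains on input_chunk_length output points, so prepend
        # the overlapping past points to the future target
        future_target = np.concatenate(
            [past_target[-(self.input_chunk_length - self.output_chunk_length):],
             future_target],
            axis=0,
        )

        covariate = None         # no past covariates in this sketch
        static_covariate = None  # no static covariates in this sketch
        sample_weight = None     # extra value required since Darts 0.30.0

        return past_target, covariate, static_covariate, sample_weight, future_target
```

With input_chunk_length=64 and output_chunk_length=1, each item yields a past target of 64 points and a future target of 64 points, so the loss shapes line up.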
Hi @dennisbader, after reading the posts you suggested, it makes sense. It looks like the model runs fine now with your correction! Thanks a lot for the help.
Best Regards, Giacomo Guiduzzi