pytorch-forecasting
Model.predict concatenates batches for ground truth along incorrect axis
When the validation dataloader contains more than one batch, the output.y returned by model.predict has an incorrect shape because the batches are concatenated along the wrong axis.
E.g. if I have 200 items with max_prediction_length=4 and the dataloader has a batch size of 100, I would expect the y attribute of the output to have shape (200, 4). Currently it has shape (100, 8), which doesn't align with the output attribute (which correctly has shape (200, 4)).
The error is the use of dim=1 in the utility concat function; it should be dim=0: https://github.com/jdb78/pytorch-forecasting/blob/cf0f2dd8afc7cc4601cc1309ff4e03f6ecf2efd7/pytorch_forecasting/models/base_model.py#L317C13-L317C13
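A minimal sketch of the shape mismatch, using plain torch.cat on dummy tensors rather than the library's own concat utility. The variable names and the zero-filled batches are illustrative assumptions, not code from pytorch-forecasting:

```python
import torch

# Numbers taken from the example above.
n_items, max_prediction_length, batch_size = 200, 4, 100

# Hypothetical stand-in for the per-batch targets a dataloader would yield:
# two batches, each of shape (batch_size, max_prediction_length).
y_batches = [
    torch.zeros(batch_size, max_prediction_length)
    for _ in range(n_items // batch_size)
]

# dim=1 joins the batches along the prediction-length axis (the reported behaviour).
wrong = torch.cat(y_batches, dim=1)

# dim=0 stacks the batches along the sample axis (the expected behaviour).
right = torch.cat(y_batches, dim=0)

print(tuple(wrong.shape))  # (100, 8)
print(tuple(right.shape))  # (200, 4)
```

With a single batch both calls return the same shape, which may be why the problem only shows up once the dataloader yields more than one batch.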