Model.predict concatenates batches for ground truth along incorrect axis

Open tRosenflanz opened this issue 2 years ago • 0 comments

When the validation dataloader contains more than one batch, the output.y returned by model.predict has an incorrect shape because the batches are concatenated along the wrong axis.

E.g., if I have 200 items with max_prediction_length=4 and the dataloader has a batch size of 100, I would expect the y attribute of the output to have shape (200, 4). Currently it has shape (100, 8), which doesn't align with the output attribute (which correctly has shape (200, 4)).

The error is the use of dim=1 in the utility concat function; it should be dim=0: https://github.com/jdb78/pytorch-forecasting/blob/cf0f2dd8afc7cc4601cc1309ff4e03f6ecf2efd7/pytorch_forecasting/models/base_model.py#L317C13-L317C13
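
A minimal sketch of the shape mismatch, using plain `torch.cat` on dummy tensors instead of the library's own concat utility (the helper's exact signature is not reproduced here). The variable names and the zero-filled batches are assumptions for illustration only; the numbers mirror the example above, with two batches of 100 items and a prediction length of 4.

```python
import torch

# Two hypothetical ground-truth batches, each (batch_size, max_prediction_length).
batch_size, max_prediction_length = 100, 4
batches = [torch.zeros(batch_size, max_prediction_length) for _ in range(2)]

# Concatenating along dim=1 stacks the batches side by side along the time axis,
# which reproduces the reported (100, 8) shape.
y_wrong = torch.cat(batches, dim=1)

# Concatenating along dim=0 stacks the batches along the sample axis,
# which gives the expected (200, 4) shape.
y_correct = torch.cat(batches, dim=0)

print(tuple(y_wrong.shape), tuple(y_correct.shape))
```

With dim=0 the ground truth has one row per item, so it lines up with the output attribute.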

tRosenflanz, Dec 11 '23 23:12