Mark Kuiack
Manually setting `embedding_sizes` when initialising the model with `.from_dataset` solved the size mismatch issue, which shows that this is the cause of the bug, i.e. ``` # embedding_sizes = {"category_column":...
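A minimal sketch of pinning the sizes up front, assuming the full set of categories for each column is known before training. The helper name, the example vocabulary, and the sizing heuristic are illustrative assumptions and are not taken from the comment above; only the `{column: (n_categories, embedding_dim)}` shape and the `embedding_sizes` keyword come from the thread.

```python
def fixed_embedding_sizes(vocab, max_dim=100):
    """Build {column: (n_categories, embedding_dim)} from a full vocabulary.

    `vocab` maps each categorical column to every category it can take.
    The dimension heuristic here is an assumption chosen for illustration.
    """
    sizes = {}
    for column, categories in vocab.items():
        n = len(set(categories))
        dim = min(round(1.6 * n ** 0.56), max_dim) if n > 2 else 1
        sizes[column] = (n, dim)
    return sizes


# Hypothetical vocabulary. In practice the count has to match what the
# dataset's encoder reports, which may include an extra unknown/NaN class.
vocab = {"category_column": ["a", "b", "c", "d"]}
embedding_sizes = fixed_embedding_sizes(vocab)

# The dict is then passed when building the model, for example:
# model = TemporalFusionTransformer.from_dataset(dataset, embedding_sizes=embedding_sizes)
```

Because the sizes are derived from the full vocabulary rather than from whichever categories happen to appear in one dataset, the training and inference models end up with embedding tables of the same shape.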
@Marcrb2 I check the embedding size on each rank and then broadcast the max embedding size required via a CPU process. This is quite a hacky solution but it works...
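A rough sketch of the idea, not the code from this comment: each rank reports its local embedding sizes, and every rank keeps the per-column maximum. The function names are hypothetical, and the gather step uses `torch.distributed.all_gather_object` as a stand-in for the broadcast described above. It assumes a process group is already initialised, for instance a CPU group created with `dist.new_group(backend="gloo")`.

```python
def merge_embedding_sizes(per_rank_sizes):
    """Take the per-column maximum over a list of embedding-size dicts.

    Each dict maps column name -> (n_categories, embedding_dim).
    """
    merged = {}
    for sizes in per_rank_sizes:
        for column, (n_categories, dim) in sizes.items():
            prev_n, prev_dim = merged.get(column, (0, 0))
            merged[column] = (max(prev_n, n_categories), max(prev_dim, dim))
    return merged


def gather_and_merge(local_sizes, group=None):
    """Collect every rank's sizes and return the merged maximum (assumed flow)."""
    import torch.distributed as dist  # imported lazily; only needed when distributed

    gathered = [None] * dist.get_world_size(group=group)
    dist.all_gather_object(gathered, local_sizes, group=group)
    return merge_embedding_sizes(gathered)


# Local illustration of the merge step with two made-up ranks.
rank0 = {"category_column": (10, 6)}
rank1 = {"category_column": (12, 6), "other_column": (3, 3)}
merged = merge_embedding_sizes([rank0, rank1])
```

Every rank then builds its model with the same `merged` dict, so the embedding tables agree across processes.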
@fkiraly Do you know why the behavior of parallel process creation, distribution, and destruction would have changed so much compared with older versions of pytorch-forecasting, PyTorch, and Lightning?
@Marcrb2
```python
import torch.distributed as dist


class GlooGroupManager:
    """Manager for the distributed gloo process group."""

    def __init__(self):
        """Initialize a new GlooGroupManager instance with no process group."""
        self._process_group = None

    def...
```
@fkiraly Do you know why the old version of pytorch-forecasting worked without issue, then only after upgrading I started seeing these errors and had to develop the work-around above?
I've just encountered this issue after trying to update python, torch, pytorch-forecasting, and pytorch-lightning. It was previously working fine to train on GPU (EC2) and then load the model and...
Here is another solution. In this case you initialise the TFT model from the inference feature dataset, then you replace the internals of the model with the trained model parameters,...
Hi @mrgreen3325 `tft.pkl` is a pickle object, not a checkpoint. Checkpoints are `.ckpt` files automatically created during training. By default they're in a `lightning_logs` directory and have names like `epoch=0-step=768.ckpt`
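A small sketch for locating those files, assuming the default `lightning_logs` directory mentioned above. The helper name is hypothetical, and picking the most recently modified file is just one reasonable choice, not necessarily the best checkpoint.

```python
from pathlib import Path


def latest_checkpoint(log_dir="lightning_logs"):
    """Return the most recently modified .ckpt under log_dir, or None."""
    checkpoints = sorted(
        Path(log_dir).rglob("*.ckpt"),
        key=lambda path: path.stat().st_mtime,
    )
    return checkpoints[-1] if checkpoints else None


path = latest_checkpoint()
# If a checkpoint was found, it can be loaded with Lightning's classmethod:
# model = TemporalFusionTransformer.load_from_checkpoint(str(path))
```

The recursive search is used because the `.ckpt` files usually sit in subdirectories of `lightning_logs` rather than directly inside it.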
I would second this and also ask to be able to send additional information to the loss function. Something like a flag for whether a value is censored on the upside...
The issue is that `plot_prediction_actual_by_variable` assumes `self.hparams.embedding_labels` is in the same order as the values in `averages_actual_cat`, but that's not the case. See here: https://github.com/sktime/pytorch-forecasting/blob/1a2d83c7a5e6769c13164eeae7f447002f61f254/pytorch_forecasting/models/base/_base_model.py#L2280 Does anyone know why this...