darts
Add TSMixer model
I recently found TSMixer (http://arxiv.org/abs/2303.06053).
It is very similar to TiDE (#1726) but with a few tweaks.
It should be pretty straightforward to implement based on the TiDE implementation (#1727).
I will try to get started on it in the next few days.
Google Research implementation: https://github.com/google-research/google-research/blob/master/tsmixer/tsmixer_basic/models/tsmixer.py
The paper is light on details, but the source code clears things up.
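For readers who haven't opened the reference code, here is a simplified NumPy sketch of a TSMixer forward pass as I understand it from the Google implementation. Each block does time mixing (an MLP along the time axis, shared across features), then feature mixing (a two-layer MLP along the feature axis), each wrapped in a residual connection. A final linear layer projects `input_len` steps to `pred_len` steps. This is illustrative only: random weights, no dropout, no training, and a layer norm over the (time, features) axes standing in for the reference code's normalization options.

```python
import numpy as np

# Shapes follow (batch, time, features) throughout.

def layer_norm(x, eps=1e-5):
    # Normalize each sample over its time and feature axes (assumption for this sketch).
    mean = x.mean(axis=(-2, -1), keepdims=True)
    var = x.var(axis=(-2, -1), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def relu(x):
    return np.maximum(x, 0.0)

def mixer_block(x, w_time, b_time, w_ff1, b_ff1, w_ff2, b_ff2):
    # Time mixing: transpose to (batch, features, time), apply a dense layer over time.
    y = layer_norm(x).transpose(0, 2, 1)
    y = relu(y @ w_time + b_time).transpose(0, 2, 1)
    res = x + y  # residual connection
    # Feature mixing: two dense layers over the feature axis.
    y = layer_norm(res)
    y = relu(y @ w_ff1 + b_ff1) @ w_ff2 + b_ff2
    return res + y  # residual connection

def tsmixer_forward(x, blocks, w_out, b_out):
    for params in blocks:
        x = mixer_block(x, *params)
    # Temporal projection: map input_len time steps to pred_len time steps.
    return (x.transpose(0, 2, 1) @ w_out + b_out).transpose(0, 2, 1)

rng = np.random.default_rng(0)
batch, seq_len, pred_len, n_feat, ff_dim, n_blocks = 4, 24, 6, 3, 16, 2
blocks = [
    (
        rng.normal(size=(seq_len, seq_len)) * 0.1, np.zeros(seq_len),
        rng.normal(size=(n_feat, ff_dim)) * 0.1, np.zeros(ff_dim),
        rng.normal(size=(ff_dim, n_feat)) * 0.1, np.zeros(n_feat),
    )
    for _ in range(n_blocks)
]
w_out = rng.normal(size=(seq_len, pred_len)) * 0.1
b_out = np.zeros(pred_len)

x = rng.normal(size=(batch, seq_len, n_feat))
y_hat = tsmixer_forward(x, blocks, w_out, b_out)
print(y_hat.shape)  # (4, 6, 3)
```

Note that the mixer blocks keep the input shape unchanged, so they can be stacked freely; only the final projection changes the time dimension.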
@alexcolpitts96 did you get to the paper's tsmixer_extended implementation? It seems to support past, static, and future covariates.
I have run into a few issues with the implementation and had some other PRs that I needed to clean up.
I managed to implement reversible instance normalization, but there is a bug in the tests that only shows up during the build process on GitHub.
The rest of the model is pretty straightforward; I just need to find the time. I just started a new job, so I have been a little short on time lately.
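For context, reversible instance normalization (RevIN) normalizes each input window by its own per-feature mean and standard deviation over time, optionally applies a learnable affine transform, and then uses the stored statistics to undo the normalization on the model's output. Below is a minimal NumPy sketch of the idea, not the code from the branch: it has no autograd, and the affine parameters are fixed at the identity instead of being learned.

```python
import numpy as np

class RevIN:
    """Minimal reversible instance normalization sketch for (batch, time, features) arrays."""

    def __init__(self, num_features, eps=1e-5, affine=True):
        self.eps = eps
        self.affine = affine
        # Learnable in a real model; fixed at the identity transform here.
        self.weight = np.ones(num_features)
        self.bias = np.zeros(num_features)

    def normalize(self, x):
        # Per-sample, per-feature statistics over the time axis, kept for denormalize().
        self.mean = x.mean(axis=1, keepdims=True)
        self.std = np.sqrt(x.var(axis=1, keepdims=True) + self.eps)
        x = (x - self.mean) / self.std
        if self.affine:
            x = x * self.weight + self.bias
        return x

    def denormalize(self, y):
        # Undo the affine transform, then restore the input window's scale and level.
        # Works for outputs with a different time length (e.g. pred_len).
        if self.affine:
            y = (y - self.bias) / (self.weight + self.eps * self.eps)
        return y * self.std + self.mean

rng = np.random.default_rng(0)
x = rng.normal(loc=10.0, scale=3.0, size=(2, 48, 3))
revin = RevIN(num_features=3)
x_norm = revin.normalize(x)
x_back = revin.denormalize(x_norm)
print(np.allclose(x_back, x))  # True

# A forecast of 12 steps is mapped back using the input window's statistics.
forecast = revin.denormalize(np.zeros((2, 12, 3)))
```

Because the statistics are stored on the object between `normalize` and `denormalize`, a real implementation has to call both within the same forward pass.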
Recently Google published a paper and an article on TSMixer: https://blog.research.google/2023/09/tsmixer-all-mlp-architecture-for-time.html
@alexcolpitts96 have you started on a PyTorch implementation that can fit into darts?
I started working on it roughly two months ago. I have been busy wrapping up school and starting a new job. I should have some time to clean it up over the next few weeks.
I managed to get the skeleton written, but I still need to add covariates and probabilistic forecasting.
https://github.com/alexcolpitts96/darts/blob/tsmixer/darts/models/forecasting/tsmixer_model.py
From my point of view, that looks good. Why do you think you need probabilistic forecasting? Does TSMixer provide it by nature? In TFT, probabilistic forecasting is a result of the quantile loss function. Maybe I am wrong, but if you want to add this feature to TSMixer, I think you just need to run it with QuantileLoss.
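To make the quantile-loss point concrete: a model that outputs one value per quantile becomes probabilistic simply by being trained with the pinball loss. Under-predictions are weighted by `q` and over-predictions by `1 - q`. Here is a small NumPy sketch of that loss (my own illustration, not darts' implementation; in darts, the torch models expose this through a `likelihood` argument such as `QuantileRegression`):

```python
import numpy as np

def quantile_loss(y_true, y_pred, quantiles):
    """Pinball loss averaged over samples and quantiles.

    y_true: shape (n,); y_pred: shape (n, len(quantiles)), one column per quantile.
    """
    q = np.asarray(quantiles)
    err = y_true[:, None] - y_pred  # positive where the model under-predicts
    # q * err penalizes under-prediction, (q - 1) * err penalizes over-prediction.
    return np.mean(np.maximum(q * err, (q - 1) * err))

y_true = np.array([10.0])
y_pred = np.array([[8.0, 10.0, 12.0]])  # predictions for the 10%, 50%, 90% quantiles
loss = quantile_loss(y_true, y_pred, [0.1, 0.5, 0.9])
print(round(loss, 4))  # 0.1333
```

Here the 10% and 90% predictions each miss by 2 on the "cheap" side (cost 0.2 each), and the median is exact, so the mean over the three quantiles is 0.4 / 3.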
@alexcolpitts96 Did you have any time to work on this further? I would be interested in using this model, and I'm also open to contributing.
IBM has released its version, PatchTSMixer, on Hugging Face. Maybe this helps to get it into darts soon.
Note that there are apparently two different models named "TSMixer":
- TSMixer (by Google): https://arxiv.org/abs/2303.06053
- TSMixer (by IBM), renamed to PatchTSMixer: https://arxiv.org/abs/2306.09364
@alexcolpitts96 @meteoDaniel @thijsjls Hi everyone, I've looked into your code, @alexcolpitts96, and it looks really good! I tried it with the following code, which uses lists of time series, covariates, and encoders:
```python
from darts.dataprocessing.transformers import Scaler
from darts.models import TSMixerModel

# n_epochs, ts_train_scaled_list, ts_val_scaled_list and cov_list
# (future covariates, one series per target) are defined elsewhere.
model_params = {
    "input_chunk_length": 240,  # hist_len
    # not tuned
    "use_static_covariates": False,
    "output_chunk_length": 37,  # pred_len
    "n_epochs": n_epochs,
}

model = TSMixerModel(
    **model_params,
    pl_trainer_kwargs={
        "accelerator": "auto",
        "devices": "auto",
    },
    # calendar features as past and future covariates, scaled by the encoder
    add_encoders={
        "datetime_attribute": {
            "past": ["hour", "day_of_week", "month"],
            "future": ["hour", "day_of_week", "month"],
        },
        "transformer": Scaler(),
    },
    model_name="tsmixer",
    save_checkpoints=True,
    force_reset=True,
)

model.fit(
    ts_train_scaled_list,
    future_covariates=cov_list,
    val_series=ts_val_scaled_list,
    val_future_covariates=cov_list,
    verbose=False,
)

# load the best model on the validation set to avoid overfitting
model = TSMixerModel.load_from_checkpoint(model_name="tsmixer", best=True)
```
It works great! The only thing I had to change in your code was an old scikit-learn import statement that has been removed from the current darts version, so merging with the latest darts version should resolve it.
I would really appreciate it if you moved forward and pushed this, as I would really like to use it and the results from TSMixer are very good. Thank you very much! I'm also happy to help!
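As a side note on the configuration above: `input_chunk_length=240` and `output_chunk_length=37` mean each training sample pairs 240 input steps with the next 37 target steps, cut from the series as sliding windows. The helper below is a hypothetical illustration (not a darts function), assuming a stride of 1 and no cap on samples per series. darts' own sampling order and limits such as `max_samples_per_ts` may differ.

```python
def training_windows(series_len, input_chunk_length, output_chunk_length):
    """Hypothetical helper: (input, target) index ranges of stride-1 sliding windows."""
    total = input_chunk_length + output_chunk_length
    windows = []
    for start in range(series_len - total + 1):
        inp = (start, start + input_chunk_length)     # half-open [start, end)
        tgt = (start + input_chunk_length, start + total)
        windows.append((inp, tgt))
    return windows

w = training_windows(300, 240, 37)
print(len(w))  # 24
print(w[0])    # ((0, 240), (240, 277))
print(w[-1])   # ((23, 263), (263, 300))
```

So a series needs at least 277 points to yield a single training sample with these settings, which is worth checking when splitting off validation series.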
I made a PR as the above seems to have gone stale. Any feedback is welcome!