Add TSMixer model

Open alexcolpitts96 opened this issue 1 year ago • 11 comments

I recently found TSMixer (http://arxiv.org/abs/2303.06053).

It is very similar to TiDE (#1726) but with a few tweaks.

It should be pretty straightforward to implement based on the implementation of TiDE (#1727).

I will try to get started on it in the next few days.

alexcolpitts96 avatar May 31 '23 19:05 alexcolpitts96

Google Research implementation: https://github.com/google-research/google-research/blob/master/tsmixer/tsmixer_basic/models/tsmixer.py

Details in the paper aren't great; however, the source code clears things up.

alexcolpitts96 avatar Jun 23 '23 03:06 alexcolpitts96

@alexcolpitts96 Did you get to implementing tsmixer_extended from the paper? It seems to support past, static, and future covariates.

joshua-xia avatar Jul 11 '23 00:07 joshua-xia

I have run into a few things with the implementation and had some other PRs that I needed to clean up.

I managed to implement reversible instance normalization, but there is a bug in the tests that only happens during the build process within GitHub.
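
For readers unfamiliar with reversible instance normalization, here is a minimal sketch of the idea: each input window is normalized by its own mean and standard deviation, and the same statistics are restored on the model output. The `RevIN` class name and its methods are hypothetical, chosen for this illustration. This is not the code from the branch discussed in this thread, and it assumes a single univariate window with no learnable scale and shift.

```python
from statistics import fmean, pstdev


class RevIN:
    """Simplified reversible instance normalization for one univariate window.

    Assumption: no learnable affine parameters; only the per-window
    statistics are stored and later reversed.
    """

    def __init__(self, eps: float = 1e-5):
        self.eps = eps  # keeps the division stable for constant windows
        self.mean = 0.0
        self.std = 1.0

    def normalize(self, window):
        # Store this window's own statistics so they can be restored later.
        self.mean = fmean(window)
        self.std = (pstdev(window) ** 2 + self.eps) ** 0.5
        return [(x - self.mean) / self.std for x in window]

    def denormalize(self, values):
        # Invert the normalization using the stored statistics.
        return [v * self.std + self.mean for v in values]


rev = RevIN()
history = [10.0, 12.0, 14.0, 16.0]
normalized = rev.normalize(history)     # what the model would see
restored = rev.denormalize(normalized)  # model output mapped back to the original scale
```

In a real model the forecast, not the normalized input, would be passed to `denormalize`; the round trip above only shows that the two steps are inverses.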

The rest of the model is pretty straightforward; I just need to find the time. I just started a new job, so I have been a little short on time lately.

alexcolpitts96 avatar Jul 11 '23 00:07 alexcolpitts96

Recently Google published a paper and an article on TSMixer: https://blog.research.google/2023/09/tsmixer-all-mlp-architecture-for-time.html

@alexcolpitts96 Have you started on a PyTorch implementation that can fit into darts?

meteoDaniel avatar Sep 15 '23 11:09 meteoDaniel

I started working on it roughly two months ago. I have been busy wrapping up school and starting a new job. I should have some time to clean it up over the next few weeks.

I managed to get the skeleton written, but I still need to add covariates and probabilistic forecasting.

https://github.com/alexcolpitts96/darts/blob/tsmixer/darts/models/forecasting/tsmixer_model.py

alexcolpitts96 avatar Sep 17 '23 04:09 alexcolpitts96

From my point of view that looks good. Why do you think you need probabilistic forecasting? Does TSMixer provide it by nature? Within TFT, the probabilistic forecast is a result of the quantile loss function. Maybe I am wrong, but if you want to add this feature to TSMixer, I think you just need to run it with QuantileLoss.
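
To make the quantile loss idea concrete, here is a minimal sketch of the pinball loss in plain Python. The function names and the dictionary layout are assumptions made for this illustration; they are not the darts API (darts normally configures quantile regression through a likelihood object on its torch models, so check its documentation for the exact usage).

```python
def pinball_loss(y_true: float, y_pred: float, q: float) -> float:
    """Quantile (pinball) loss for one observation at quantile level q in (0, 1)."""
    diff = y_true - y_pred
    # Under-prediction (diff > 0) costs q * diff;
    # over-prediction (diff < 0) costs (1 - q) * |diff|.
    return max(q * diff, (q - 1.0) * diff)


def mean_quantile_loss(y_true, preds_by_quantile):
    """Average pinball loss over all time steps and all quantile levels.

    preds_by_quantile maps each quantile level to a list of predictions
    aligned with y_true (a hypothetical layout chosen for this sketch).
    """
    total, count = 0.0, 0
    for q, preds in preds_by_quantile.items():
        for target, pred in zip(y_true, preds):
            total += pinball_loss(target, pred, q)
            count += 1
    return total / count


# One prediction per quantile level and time step.
y = [10.0, 12.0]
preds = {0.1: [8.0, 11.0], 0.5: [10.0, 12.5], 0.9: [13.0, 14.0]}
loss = mean_quantile_loss(y, preds)
```

A model trained with this loss outputs one value per quantile level, which is how a deterministic architecture can produce a probabilistic forecast without changing its core layers.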

meteoDaniel avatar Sep 17 '23 16:09 meteoDaniel

@alexcolpitts96 Did you have any time to work on this further? I would be interested in using this model and am also open to contributing.

thijsjls avatar Nov 21 '23 14:11 thijsjls

IBM has released its version, PatchTSMixer, on HuggingFace. Maybe this helps to make it available in darts soon.

StatMixedML avatar Jan 15 '24 15:01 StatMixedML

Note that there are apparently two different models named "TSMixer":

  • TSMixer (by Google): https://arxiv.org/abs/2303.06053
  • TSMixer (by IBM), renamed to PatchTSMixer: https://arxiv.org/abs/2306.09364

candalfigomoro avatar Feb 14 '24 09:02 candalfigomoro

@alexcolpitts96 @meteoDaniel @thijsjls Hi everyone, I've looked into your code @alexcolpitts96 and it looks really good! I've tried it, using the following code, including lists of timeseries, covariates, encoders:

from darts.dataprocessing.transformers import Scaler
# Assumes TSMixerModel is exported from darts.models in the installed version or branch.
from darts.models import TSMixerModel

# n_epochs, the scaled series lists, and cov_list are defined elsewhere.
model_params = {
    "input_chunk_length": 240,  # hist_len
    # not tuned
    "use_static_covariates": False,
    "output_chunk_length": 37,  # pred_len
    "n_epochs": n_epochs,
}

model = TSMixerModel(
    **model_params,
    pl_trainer_kwargs={
        "accelerator": "auto",
        "devices": "auto",
    },
    # Generate calendar features as past and future covariates and scale them.
    add_encoders={
        "datetime_attribute": {
            "past": ["hour", "day_of_week", "month"],
            "future": ["hour", "day_of_week", "month"],
        },
        "transformer": Scaler(),
    },
    model_name="tsmixer",
    save_checkpoints=True,
    force_reset=True,
)

model.fit(
    ts_train_scaled_list,
    future_covariates=cov_list,
    val_series=ts_val_scaled_list,
    val_future_covariates=cov_list,
    verbose=False,
)

# Load the best model on the validation set to avoid overfitting.
model = TSMixerModel.load_from_checkpoint(model_name="tsmixer", best=True)

and it works great! The only thing I had to change in your code was an old import statement from scikit-learn, which has been removed from the current darts version, so merging with the newest darts version should resolve it.

I would really appreciate it if you could go forward and push this, as I would really like to use it and the results from TSMixer are so good. Thank you very much! I'm also very happy to help!

leoniewgnr avatar Mar 06 '24 01:03 leoniewgnr

I made a PR as the above seems to have gone stale. Any feedback is welcome!

cristof-r avatar Mar 21 '24 14:03 cristof-r