
[ENH] A base template for DL estimators

Open AurumnPegasus opened this issue 2 years ago • 6 comments

Is your feature request related to a problem? Please describe.
As discussed at dev-days-2022, we want to add a base template for deep learning (DL) estimators, to make it easier for users to add new ones.
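One common shape for such a base template is the template method pattern: the base class implements fit once, and each concrete estimator only supplies its network architecture. The sketch below assumes that design. BaseDeepTemplate, build_model, TinyClassifier and StubNetwork are hypothetical names, and StubNetwork stands in for a keras model so the example runs without TensorFlow.

```python
from abc import ABC, abstractmethod


class StubNetwork:
    """Stand-in for a compiled keras model; it only records its shapes."""

    def __init__(self, input_shape, n_classes):
        self.input_shape = input_shape
        self.n_classes = n_classes

    def fit(self, X, y, **kwargs):
        return self


class BaseDeepTemplate(ABC):
    """Hypothetical base template: shared fit logic is written once here."""

    def fit(self, X, y):
        # Steps every DL estimator needs: infer shapes, build, then train.
        input_shape = (len(X[0]),)
        n_classes = len(set(y))
        self.model_ = self.build_model(input_shape, n_classes)
        self.model_.fit(X, y)
        return self

    @abstractmethod
    def build_model(self, input_shape, n_classes):
        """Concrete estimators only define the network architecture."""


class TinyClassifier(BaseDeepTemplate):
    def build_model(self, input_shape, n_classes):
        return StubNetwork(input_shape, n_classes)
```

With this split, adding a new DL estimator means writing one build_model method; the base class cannot be instantiated on its own because build_model is abstract.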

AurumnPegasus, Jul 15 '22 15:07

Related to that (as we discussed during your talk, thanks btw!): it seems necessary to extend the fit method of the base class (sktime/classification/deep_learning/base.py) so that arguments can be passed through to keras. We should introduce a dict for that. I am happy to work on it after finishing my first DL classifier port (InceptionTime) from sktime-dl.

tobiasweede, Jul 15 '22 15:07

Hm, design-wise, I would avoid adding additional arguments to fit where possible.

The deep learning algorithms should comply with the BaseForecaster, BaseClassifier, etc. interfaces, and to follow the "strategy pattern", their public methods must always have the same signature.

Otherwise, estimators of the same type are not easily exchangeable.
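To illustrate why a shared signature matters, here is a toy sketch with hypothetical classes (not sktime's BaseForecaster API): evaluate works with any estimator exposing fit(y) and predict(horizon), so the two forecasters can be swapped without changing the calling code.

```python
class MeanForecaster:
    def fit(self, y):
        self.value_ = sum(y) / len(y)
        return self

    def predict(self, horizon):
        return [self.value_] * horizon


class LastValueForecaster:
    def fit(self, y):
        self.value_ = y[-1]
        return self

    def predict(self, horizon):
        return [self.value_] * horizon


def evaluate(forecaster, y_train, horizon):
    # Relies only on the shared fit/predict signature, not on the class.
    return forecaster.fit(y_train).predict(horizon)


results = {
    type(f).__name__: evaluate(f, [1, 2, 3], 2)
    for f in (MeanForecaster(), LastValueForecaster())
}
```

A deep learning forecaster whose fit required an extra argument would break evaluate, which is the exchangeability problem described above.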

The "usual" way would be to add them to the constructor and pass them through to the underlying fit call. Examples of this can be found in the statsmodels and pmdarima models we have interfaced.
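A minimal sketch of the constructor pass-through pattern, assuming a wrapped third-party model whose fit takes extra options. ThirdPartyModel, WrappedForecaster and the maxiter/method parameters are hypothetical stand-ins, not the actual statsmodels or pmdarima interfaces.

```python
class ThirdPartyModel:
    """Stand-in for an external library's model with its own fit options."""

    def fit(self, y, maxiter=50, method="lbfgs"):
        self.used_ = {"maxiter": maxiter, "method": method}
        return self


class WrappedForecaster:
    def __init__(self, maxiter=50, method="lbfgs"):
        # Hyperparameters are stored unchanged, sklearn-style.
        self.maxiter = maxiter
        self.method = method

    def fit(self, y):
        # fit keeps the common signature; options are passed through here.
        self.model_ = ThirdPartyModel().fit(
            y, maxiter=self.maxiter, method=self.method
        )
        return self
```

The user configures the library-specific options once, at construction time, while fit(y) looks the same as for every other forecaster.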

Which fit arguments are you considering to add?

fkiraly, Jul 17 '22 22:07

Hey @tobiasweede, @fkiraly and I had a short discussion about this during the daily standup:

  1. I am assuming that you want a way to pass the various parameters that keras.fit accepts into the DL classes.
  2. Franz suggested that, instead of passing a dict of arguments to fit, we could specify it in the __init__ of the DL class, so that we avoid changing the fit signatures of the different base classes. I agree with the idea, and think we could add an argument called, say, fit_args to __init__, which fit then passes as kwargs to keras.fit.
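A minimal sketch of the fit_args idea, assuming the estimator builds a default set of keyword arguments and lets fit_args override them. DeepEstimatorSketch and StubKerasModel are hypothetical; the stub only records the kwargs it receives so the example runs without keras. Whether fit_args should override n_epochs and batch_size, or the reverse, is a design choice this sketch does not settle.

```python
class StubKerasModel:
    """Stand-in for a keras model: records the keyword arguments fit gets."""

    def fit(self, x, y, **kwargs):
        self.fit_kwargs_ = kwargs
        return self


class DeepEstimatorSketch:
    def __init__(self, n_epochs=100, batch_size=16, fit_args=None):
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        # Stored unchanged; fit builds a new dict and never mutates it.
        self.fit_args = fit_args

    def fit(self, X, y):
        kwargs = {"epochs": self.n_epochs, "batch_size": self.batch_size}
        # Extra options (e.g. verbose, callbacks) override the defaults.
        kwargs.update(self.fit_args or {})
        self.model_ = StubKerasModel()
        self.model_.fit(X, y, **kwargs)
        return self
```

For example, DeepEstimatorSketch(fit_args={"verbose": 0}) silences training output without any change to the fit(X, y) signature.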

AurumnPegasus, Jul 19 '22 13:07

@tobiasweede, I've assigned this to you and put it on the deep learning project board for tracking. Talk to you later!

fkiraly, Jul 22 '22 13:07

Hey @tobiasweede, any update on this?

AurumnPegasus, Jul 27 '22 09:07

Note: until this issue is resolved, I am exempting CNNNetwork from test_inheritance. After that, we should either register it as a base class, or add a rule to the test for concrete classes that inherit directly from BaseObject (the latter is currently not allowed).
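The check described in the note can be sketched roughly as follows. This is not sktime's actual test_inheritance; the class bodies, REGISTERED_BASES, EXCEPTED and inheritance_violations are hypothetical stand-ins that only show the idea of "every concrete class must derive from a registered base class, unless explicitly exempted".

```python
class BaseObject:
    pass


class BaseClassifier(BaseObject):
    pass


class SomeClassifier(BaseClassifier):
    pass


class CNNNetwork(BaseObject):
    # Inherits directly from BaseObject, which the rule does not allow.
    pass


REGISTERED_BASES = (BaseClassifier,)
EXCEPTED = {"CNNNetwork"}  # temporary exemption until the issue is resolved


def inheritance_violations(classes):
    """Return names of concrete classes that derive from no registered base."""
    return [
        cls.__name__
        for cls in classes
        if cls.__name__ not in EXCEPTED and not issubclass(cls, REGISTERED_BASES)
    ]
```

Registering a network base class would mean adding it to REGISTERED_BASES and dropping the exemption; the alternative is a separate rule that permits direct BaseObject subclasses.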

fkiraly, Jul 29 '22 09:07