EconML
How to pass args to the fit function for nuisance models
Some models accept certain arguments only in their fit function. One such example is the categorical_feature argument in lightgbm. EconML doesn't seem to pass fit arguments to nuisance models (cf. DRLearner.fit). Is there a way to pass arguments to the fit functions of the nuisance models?
I use a wrapper class, like here
In your case, you'd inherit from the lightgbm sklearn-style class instead of AutoML (though FLAML is worth trying as well :) )
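A minimal sketch of the inheritance approach, under stated assumptions. To keep it runnable without lightgbm installed, `StandInRegressor` is a hypothetical stand-in for `lightgbm.LGBMRegressor` that only records what its `fit()` receives. `CategoricalRegressor`, `FIT_PARAMS` and the column index `[0]` are also hypothetical. The fixed fit arguments live in a class attribute rather than in the constructor, so the inherited constructor, and anything that inspects it, is unchanged.

```python
class StandInRegressor:
    """Hypothetical stand-in for an sklearn-style estimator such as
    lightgbm.LGBMRegressor: it only records the keyword arguments fit() gets."""

    def __init__(self, n_estimators=100):
        self.n_estimators = n_estimators

    def fit(self, X, y, **fit_kwargs):
        self.last_fit_kwargs_ = fit_kwargs
        return self


class CategoricalRegressor(StandInRegressor):  # in practice: lightgbm.LGBMRegressor
    """Wrapper that always forwards fixed fit() arguments to the parent fit().

    FIT_PARAMS is a class attribute, not a constructor argument, so it does not
    show up in the __init__ signature. The column index is a hypothetical example.
    """

    FIT_PARAMS = {"categorical_feature": [0]}

    def fit(self, X, y, **kwargs):
        # Whatever the caller passes (e.g. sample_weight) is kept, and it
        # overrides FIT_PARAMS if the same key appears in both.
        return super().fit(X, y, **{**self.FIT_PARAMS, **kwargs})


model = CategoricalRegressor(n_estimators=10)
model.fit([[0, 1.5], [1, 2.5]], [1.0, 2.0], sample_weight=[1.0, 2.0])
print(model.last_fit_kwargs_)
# {'categorical_feature': [0], 'sample_weight': [1.0, 2.0]}
```

With lightgbm installed, you would swap the base class for `lightgbm.LGBMRegressor` (or `LGBMClassifier` for a propensity model) and pass an instance to EconML, e.g. `DRLearner(model_regression=CategoricalRegressor(), ...)`. The column positions in `FIT_PARAMS` refer to whatever matrix EconML hands the nuisance model, which may stack X, W and treatment columns, so treat them as an assumption to verify.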
Thanks for sharing! I also tried the wrapper class approach. The problem with a wrapper class is that some packages, like lightgbm, inspect the parameters in the constructor and emit warnings such as "UserWarning: categorical_feature keyword has been found in params and will be ignored" (cf. source code). These warnings are false alarms, but they are still a bit annoying.
That issue doesn't seem to arise in my version of the wrapper, as best I can see - or can you elaborate how it does?
lightgbm checks a set of parameters (including 'feature_name', 'categorical_feature', etc.). The detailed logic is in the _lazy_init function here. You probably won't see the warning if you don't touch those parameters or don't use lightgbm.
I get that, but my version of the wrapper only passes to the parent's fit() method exactly what you put in fit_params (plus of course whatever the calling function passes), so if you only pass it the things that belong there, that warning should never be triggered, no?
I think using a wrapper is the best workaround for now, but we'll consider adding direct support - see https://github.com/microsoft/EconML/issues/543.
I think the problem is that the fit parameters are saved internally as instance attributes, and that's what lightgbm eventually checks (through sklearn.base.BaseEstimator.get_params()).
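To make the mechanism behind that hypothesis concrete: scikit-learn's `get_params()` inspects the estimator's `__init__` signature and reads each named argument back from the attribute of the same name. The stdlib-only sketch below imitates that lookup. `signature_params`, `CtorWrapper` and `FitOnlyWrapper` are hypothetical names. The real method also handles `deep=True`, and libraries such as lightgbm may add keys of their own on top.

```python
import inspect
from typing import Any, Dict


def signature_params(est: Any) -> Dict[str, Any]:
    """Simplified imitation of sklearn's BaseEstimator.get_params(deep=False):
    every named __init__ argument is read back from the same-named attribute."""
    sig = inspect.signature(type(est).__init__)
    names = [
        p.name
        for p in sig.parameters.values()
        if p.name != "self"
        and p.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    ]
    return {name: getattr(est, name) for name in sorted(names)}


class CtorWrapper:
    """Hypothetical wrapper that takes a fit argument in its constructor."""

    def __init__(self, n_estimators=100, categorical_feature=None):
        self.n_estimators = n_estimators
        self.categorical_feature = categorical_feature


class FitOnlyWrapper:
    """Hypothetical wrapper that keeps the fit argument out of __init__."""

    FIT_PARAMS = {"categorical_feature": [0]}

    def __init__(self, n_estimators=100):
        self.n_estimators = n_estimators


print(signature_params(CtorWrapper(categorical_feature=[0])))
# {'categorical_feature': [0], 'n_estimators': 100}
print(signature_params(FitOnlyWrapper()))
# {'n_estimators': 100}
```

A wrapper that accepts `categorical_feature` in its constructor surfaces it through this lookup, where a library that builds its training parameters from `get_params()` could flag it. Keeping the argument out of `__init__` and injecting it only in `fit()` avoids that path.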