Ensemble-Pytorch

How to script the forward pass?

Status: Open · francescamanni1989 opened this issue 4 years ago · 6 comments

Hi everyone,

I am trying to script the ensemble, but variable positional arguments (`*args`) cannot be used with TorchScript:

```
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File ".....\lib\site-packages\torchensemble\soft_gradient_boosting.py", line 390
        "classifier_forward",
    )
    def forward(self, *x):
                      ~~ <--- HERE
        output = [estimator(*x) for estimator in self.estimators_]
        output = op.sum_with_multiplicative(output, self.shrinkage_rate)
```

Do you have any idea how to handle it?
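For reference, a minimal sketch of why this fails, assuming a recent PyTorch: the restriction comes from TorchScript itself, so any module whose `forward` takes `*x` is rejected at compile time, while the same module with one explicitly typed input compiles. `VarArgsModule`, `FixedArgModule` and `can_script` are hypothetical names for illustration, not part of torchensemble.

```python
import torch
import torch.nn as nn


class VarArgsModule(nn.Module):
    """Hypothetical stand-in that mirrors torchensemble's `forward(self, *x)`."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 3)

    def forward(self, *x):
        return self.linear(*x)


class FixedArgModule(nn.Module):
    """Same computation, but with a single, explicitly typed input."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


def can_script(module: nn.Module) -> bool:
    """Return True if torch.jit.script manages to compile the module."""
    try:
        torch.jit.script(module)
        return True
    except Exception:  # e.g. torch.jit.frontend.NotSupportedError
        return False


print(can_script(VarArgsModule()))   # False
print(can_script(FixedArgModule()))  # True
```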

francescamanni1989 · Sep 17 '21 15:09

Thanks for reporting! @francescamanni1989

Could you provide a code snippet that reproduces the error?

xuyxu · Sep 18 '21 11:09

Hi,

The problematic code in gradient_boosting.py is the variable-argument (`*x`) `forward` that is used when boosting is performed:

```python
def forward(self, *x):
    output = [estimator(*x) for estimator in self.estimators_]
    output = op.sum_with_multiplicative(output, self.shrinkage_rate)
    output = F.softmax(output, dim=1)
    return output
```

The error occurs when I try to script the model:

```python
import torch

model = model_ensemble
scripted_model = torch.jit.script(model)  # raises torch.jit.frontend.NotSupportedError
```

where model_ensemble could be:

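```python
from torchensemble import GradientBoostingClassifier

# MLP is the user's own base-estimator class (an nn.Module subclass)
model_ensemble = GradientBoostingClassifier(
    estimator=MLP,
    n_estimators=10,
    cuda=False,
    shrinkage_rate=0.9,
)
```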
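One possible workaround, sketched under assumptions: keep the trained base estimators, but call them through a small wrapper whose `forward` has a fixed, typed signature and iterates `self.estimators_` with a plain `for` loop. It assumes every estimator takes a single tensor and returns class logits, and that `op.sum_with_multiplicative(output, self.shrinkage_rate)` computes `shrinkage_rate * sum(output)`, as its name suggests. `ScriptableBoostingForward` is a hypothetical class, not a torchensemble API.

```python
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F


class ScriptableBoostingForward(nn.Module):
    """Hypothetical wrapper: fixed-arity forward over already-trained estimators."""

    def __init__(self, estimators: nn.ModuleList, shrinkage_rate: float) -> None:
        super().__init__()
        self.estimators_ = estimators
        self.shrinkage_rate = shrinkage_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outputs: List[torch.Tensor] = []
        # Plain for loop over the ModuleList; TorchScript unrolls it.
        for estimator in self.estimators_:
            outputs.append(estimator(x))
        # Assumed equivalent of op.sum_with_multiplicative: rate * sum of outputs.
        output = torch.stack(outputs, dim=0).sum(dim=0) * self.shrinkage_rate
        return F.softmax(output, dim=1)


# Stand-in estimators; in practice these would be the fitted base estimators.
estimators = nn.ModuleList([nn.Linear(4, 3) for _ in range(10)])
wrapper = ScriptableBoostingForward(estimators, shrinkage_rate=0.9)
scripted_model = torch.jit.script(wrapper)
probs = scripted_model(torch.randn(5, 4))
print(probs.shape)  # torch.Size([5, 3])
```

After `fit`, such a wrapper could be built as `ScriptableBoostingForward(model_ensemble.estimators_, model_ensemble.shrinkage_rate)` (attribute names taken from the `forward` quoted above) and passed to `torch.jit.script`, provided the `MLP` estimators are themselves scriptable.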

francescamanni1989 · Sep 19 '21 17:09

It looks like the package does not support TorchScript well for now. I will take a careful look when I get a moment, thanks!
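As a possible stopgap, sketched under assumptions: if a TorchScript module is needed but scripting specifically is not, `torch.jit.trace` runs `forward` once on example inputs and records the tensor operations, so neither the `*x` signature nor the builtin `sum` has to be compiled. The caveat is that Python control flow is frozen to the path taken for the example input. `VarArgsEnsemble` below is a hypothetical stand-in, not the torchensemble class.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VarArgsEnsemble(nn.Module):
    """Hypothetical stand-in that keeps torchensemble's forward(self, *x) shape."""

    def __init__(self, n_estimators: int = 3, shrinkage_rate: float = 0.9) -> None:
        super().__init__()
        self.estimators_ = nn.ModuleList(
            [nn.Linear(4, 3) for _ in range(n_estimators)]
        )
        self.shrinkage_rate = shrinkage_rate

    def forward(self, *x):
        output = [estimator(*x) for estimator in self.estimators_]
        # Builtin sum and *x are fine here: tracing never compiles this source.
        output = self.shrinkage_rate * sum(output)
        return F.softmax(output, dim=1)


model = VarArgsEnsemble().eval()
example = torch.randn(5, 4)
# Example inputs are passed as a tuple, one entry per positional argument.
traced_model = torch.jit.trace(model, (example,))
print(torch.allclose(traced_model(example), model(example)))  # True
```

The traced module accepts exactly as many tensors as the example tuple contained, and it can be saved with `torch.jit.save`.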

xuyxu · Sep 20 '21 01:09

Exactly! Thank you

francescamanni1989 · Sep 20 '21 06:09

Also, the `sum` function is not scriptable, but this can be bypassed using `@torch.jit.ignore`.
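A sketch of that bypass, assuming the helper simply scales the Python `sum` of the estimator outputs. Both `sum_with_multiplicative` below (a hypothetical re-implementation, not torchensemble's `op` module) and `ShrunkSumEnsemble` are illustrative. An ignored function is dispatched back to the Python interpreter at run time, so scripting succeeds; note that, per the PyTorch documentation, a model that calls ignored functions cannot be exported with `torch.jit.save`.

```python
from typing import List

import torch
import torch.nn as nn


@torch.jit.ignore
def sum_with_multiplicative(outputs: List[torch.Tensor], factor: float) -> torch.Tensor:
    # Hypothetical re-implementation; TorchScript leaves this as a Python call,
    # so the builtin sum over a list of tensors is never compiled.
    return factor * sum(outputs)


class ShrunkSumEnsemble(nn.Module):
    """Hypothetical fixed-arity ensemble that delegates the sum to Python."""

    def __init__(self, n_estimators: int = 3, shrinkage_rate: float = 0.9) -> None:
        super().__init__()
        self.estimators_ = nn.ModuleList(
            [nn.Linear(4, 3) for _ in range(n_estimators)]
        )
        self.shrinkage_rate = shrinkage_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outputs: List[torch.Tensor] = []
        for estimator in self.estimators_:
            outputs.append(estimator(x))
        return sum_with_multiplicative(outputs, self.shrinkage_rate)


model = ShrunkSumEnsemble()
scripted_model = torch.jit.script(model)  # compiles; the helper runs in Python
x = torch.randn(5, 4)
print(torch.allclose(scripted_model(x), model(x)))  # True
```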

francescamanni1989 · Sep 20 '21 08:09

For the indexed variable, my suggestion is to iterate with a plain `for` loop instead.

francescamanni1989 · Sep 20 '21 08:09