mia
PyTorch support documentation and examples
Hi, you mention in the README that the package supports PyTorch models, but ShadowModelBundle._fit assumes the model has a fit method (line 116).
How exactly have you tested the PyTorch models? I was thinking of maybe using pytorch-fitmodule or SuperModule, but if there's a way you already recommend, that would be great. It would also be nice to include an example of how to load PyTorch modules in the package! (Maybe I can do a PR once I've managed to do it myself :-)
Hi, ShadowModelBundle and AttackModelBundle take a scikit-like object, with fit, predict, and predict_proba methods. You can use skorch to wrap your torch model in such an API, or mia's own mia.wrappers.TorchWrapper. You can see example tests that use skorch (shadow, attack, serializers).
If you can add an example, that would be great!
Ah great, I hadn't seen the TorchWrapper! In the end I just wrote my own class, but using skorch should indeed be better, so I'll do that instead. Thanks.