feat(skorch): add an inherited class from skorch.NeuralNet that is compatible with PyTorch Frame
Closes #147
~~@MacOS Please continue from here if it helps. Sorry for being so loud, but this took me a whole day, so I would appreciate it very much if you could add me as a co-author if you use this code.~~
Codecov Report
All modified and coverable lines are covered by tests :white_check_mark:
Project coverage is 93.52%. Comparing base (ee98b87) to head (aa5484d). Report is 6 commits behind head on master.
:exclamation: Current head aa5484d differs from pull request most recent head cb76e8d. Consider uploading reports for the commit cb76e8d to get more accurate results
Coverage Diff

| | master | #375 | +/- |
|---|---|---|---|
| Coverage | 93.52% | 93.52% | |
| Files | 124 | 124 | |
| Lines | 6456 | 6456 | |
| Hits | 6038 | 6038 | |
| Misses | 418 | 418 | |
Sure - on both!
@weihua916 Would you mind reviewing this and letting me know whether you think it is a good way to implement it?
Also, it is strange that mypy in pre-commit does not raise errors, but mypy in CI does. I don't think there is any way to deal with this.
@weihua916 @zechengz @yiweny Would appreciate your review, thank you very much in advance.
Just checking in: is there any progress here?
No progress, sorry
- First, `torch_frame.DataSet` cannot properly handle a `DataFrame` whose index is not `[0, 1, ...]`. `dataset.py` L736 must be modified to `indices = np.where(self.df[self.split_col] == SPLIT_TO_NUM[split])[0]`.
- ~~Second, though I was not aware of this before, `torch_frame` models require `col_stats` and `col_names_dict` in the constructor and unfortunately do not allow modifying them after instantiation. Furthermore, they do not follow scikit-learn conventions, as they do not store their constructor parameters, making it impossible to instantiate the module again.~~ Resolved by allowing a function that creates the module to be passed.
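The "pass a function that creates the module" approach can be sketched as follows. This is a minimal illustration in plain Python so that it runs without torch, skorch, or torch_frame installed: `FrameEstimator`, `TinyModel`, and `make_module` are hypothetical names, not the PR's actual API, and the plain dicts stand in for torch_frame's `col_stats` and `col_names_dict`. The idea is that the estimator stores only its constructor arguments (as scikit-learn expects) and builds the module inside `fit()`, once the data-dependent arguments are known.

```python
from typing import Any, Callable, Dict, List

# Hypothetical stand-ins for torch_frame's col_stats / col_names_dict;
# plain dicts keep this sketch runnable without torch or torch_frame.
ColStats = Dict[str, Dict[str, float]]
ColNamesDict = Dict[str, List[str]]
ModuleFn = Callable[[ColStats, ColNamesDict], Any]


class TinyModel:
    """Hypothetical stand-in for a model that needs dataset-dependent
    arguments in its constructor."""

    def __init__(self, col_stats: ColStats, col_names_dict: ColNamesDict) -> None:
        self.col_stats = col_stats
        self.col_names_dict = col_names_dict


class FrameEstimator:
    """Hypothetical scikit-learn-style wrapper (not the PR's class)."""

    def __init__(self, module_fn: ModuleFn) -> None:
        # Store the constructor argument unchanged, so a fresh estimator
        # can be re-created from get_params().
        self.module_fn = module_fn

    def get_params(self, deep: bool = True) -> Dict[str, Any]:
        return {"module_fn": self.module_fn}

    def fit(self, rows: List[Dict[str, float]]) -> "FrameEstimator":
        cols = sorted(rows[0])
        # Derive simple per-column statistics from the training data.
        col_stats = {
            c: {"mean": sum(r[c] for r in rows) / len(rows)} for c in cols
        }
        col_names_dict = {"numerical": cols}
        # The module is only built now, from data-dependent arguments.
        self.module_ = self.module_fn(col_stats, col_names_dict)
        return self


def make_module(col_stats: ColStats, col_names_dict: ColNamesDict) -> TinyModel:
    return TinyModel(col_stats, col_names_dict)


rows = [{"a": 1.0, "b": 4.0}, {"a": 3.0, "b": 6.0}]
est = FrameEstimator(module_fn=make_module).fit(rows)
# An unfitted copy can be rebuilt from the stored parameters alone.
fresh = FrameEstimator(**est.get_params())
```

Because the factory, not a pre-built module instance, is the constructor parameter, re-instantiating the estimator (as scikit-learn's cloning does) never needs access to the original module's internal state.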
Almost done
@weihua916 Thanks for your patience, I've added tests.
@weihua916 Removed all tutorials and added examples/sklearn_api.py instead.
The failing tests are probably due to a pandas version update and are not related to this PR.
@weihua916 Would you please reconsider merging this PR?
This change makes it easy to try out torch-frame-based neural networks in code that already uses scikit-learn (as demonstrated in https://github.com/pyg-team/pytorch-frame/pull/375#discussion_r1673507323).
(This PR is not intended to simplify training PyTorch models in code that does not use scikit-learn; Lightning or fastai should be used for such applications.)
Thank you in advance.
@weihua916 @yiweny @akihironitta Any chance that this PR could be merged?
Thanks for your positive reply. In that case, I will consider making a separate package, but I am very busy right now, so please leave this as it is for a while. Note that in any case, the one-line change I made in `dataset.py` needs to be merged for that to work, so you may want to review that specific line.
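Why the one-line `dataset.py` change matters can be illustrated with made-up data. This sketch does not reproduce `dataset.py`; `SPLIT_TO_NUM` below is a local stand-in for the mapping the fix refers to, and the label-based selection is just one way indices can go wrong when a `DataFrame` index is not `[0, 1, ...]`. `np.where` on the boolean mask returns row positions, which is what row-wise array or tensor indexing needs.

```python
import numpy as np
import pandas as pd

# Assumed stand-in for the split-name -> integer mapping used by the fix.
SPLIT_TO_NUM = {"train": 0, "val": 1, "test": 2}

# A DataFrame whose index is not [0, 1, ...], e.g. after shuffling or
# after a train/test split done elsewhere with scikit-learn utilities.
df = pd.DataFrame(
    {"x": [1.0, 2.0, 3.0, 4.0], "split": [0, 1, 0, 2]},
    index=[10, 42, 7, 3],
)
mask = df["split"] == SPLIT_TO_NUM["train"]

# Label-based selection: these are index labels, not row positions.
label_indices = df.index[mask.to_numpy()].to_numpy()

# Positional selection, as in the proposed fix.
positional_indices = np.where(mask)[0]

x = df["x"].to_numpy()
# x[label_indices] would raise IndexError here (label 10 is out of
# bounds for 4 rows); positional indices select the intended rows.
train_x = x[positional_indices]
```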