tabnet
Save TabNet in ONNX format
Feature request
Is it possible to save TabNet in ONNX format?
What is the expected behavior?
What is the motivation or use case for adding/changing the behavior? ONNX is quickly becoming the de facto standard for saving models, also because that way you avoid having to import the training packages when you pack a model for inference.
How should this be implemented in your opinion?
Are you willing to work on this yourself? Well, for now I don't have a precise idea, but I'm willing to help if I get some suggestions on where to start.
The feature could probably be implemented as a notebook example, therefore with no changes needed to the core implementation.
I'm not familiar with ONNX, but it would be quite easy to save the network as a traced script (from pytorch jit: https://pytorch.org/docs/stable/jit.html), which could be used for inference without the need for pytorch-tabnet, and even without Python itself (it can be called from C++ alone).
Training the model and then saving model.network in eval mode should work without problems. I think ONNX would work the same way. But production-ready requirements are specific to each environment (some will have Python inside Docker with the library available for inference, some only C++, some will want the explanations along with the predictions while some will only care about predictions). So I feel we can only give some examples of how this would work; since the network is accessible and is just a simple torch.nn.Module, I feel it is a bit beyond the scope of the library.
Feel free to open a PR giving examples for either ONNX or jit and I'll be happy to review (not sure about adding onnx as a dependency in the repo, however). If you have questions I might be able to help with jit; I guess it would look something like this:
import torch

def save_torch_script(tab_model, X_infer_ex, saving_path, model_name):
    """
    Utility function to save a tabnet model as torch script
    Parameters
    ----------
    - tab_model : pytorch-tabnet model
        A trained model whose network to save
    - X_infer_ex : torch.Tensor of size (B, D) (EDIT: a torch.Tensor, not a 2D np.array)
        Batch containing B examples with D features
    - saving_path : str
        Path to save the file
    - model_name : str
        Name of the file to create, should not contain the extension
    Returns
    -------
    - traced_script_module : torch.jit.ScriptModule
        Traced model that has been saved
    """
    tab_model.network.eval()
    # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
    traced_script_module = torch.jit.trace(tab_model.network, X_infer_ex)
    traced_script_module.save(saving_path + model_name + ".pt")
    return traced_script_module
I think that's it! (This will only trace the forward pass, without the explanations; you'll need to create a wrapper with a custom forward function to get both preds and explanations.)
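For completeness, a minimal sketch of loading the saved file back for inference (note that the network's forward may return extra terms, e.g. a sparsity loss, alongside the predictions, so the output may be a tuple):

import torch

# Load the traced module -- pytorch-tabnet is no longer needed at this point
loaded = torch.jit.load(saving_path + model_name + ".pt")
loaded.eval()
with torch.no_grad():
    outputs = loaded(X_infer_ex)  # same input format as used for tracing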
Hi Optimox, thanks for your quick answer.
My main use case is being able to package the trained model into a REST service for predictions, in Python.
As far as I understand (I'm no great PyTorch expert), a TabNetClassifier is a torch.nn.Module, so as explained here:
https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
we should be able to export through tracing, using torch.onnx.export
As soon as I have time, I'll follow that approach; I need to check what the params mean.
I agree with you that onnx shouldn't be added to TabNet as a dependency; I was thinking of adding an example notebook with best practices.
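For reference, following that tutorial the export call would look roughly like this (a sketch only; the feature count is a placeholder, and as the rest of this thread shows, the export currently hits snags with tabnet's custom functions):

import torch

D = 54  # hypothetical number of input features
# clf is a trained TabNetClassifier; export the underlying torch.nn.Module,
# not the sklearn-style wrapper
dummy_input = torch.randn(1, D, dtype=torch.float32)
torch.onnx.export(
    clf.network,
    dummy_input,
    "tabnet.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}},  # allow variable batch size
)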
It seems more difficult than I expected. When I call:
torch.onnx.export(clf.network, dummy_input, "tabnet1", verbose=True)
I get the following error:
RuntimeError: Only tuples, lists, and Variables supported as JIT inputs/outputs. Dictionaries and strings are also accepted but their usage is not recommended. But got unsupported type numpy.ndarray
It seems:
- you cannot export if the input is a NumPy array
- but TabNetClassifier doesn't accept a Torch tensor as input
What kind of data can I pass as input? Only a NumPy array? Strange enough, since under the cover it is PyTorch. It should be easy to accept a Tensor.
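A likely fix for the first error (an assumption, not verified in this thread): convert a sample of the array to a tensor just for the export call, since tracing only needs an example input of the right shape and dtype:

import numpy as np
import torch

# Stand-in for the real training data (hypothetical shape);
# clf is the trained TabNetClassifier from above
X_train = np.random.rand(10, 54).astype(np.float32)
dummy_input = torch.from_numpy(X_train[:1]).float()
torch.onnx.export(clf.network, dummy_input, "tabnet1.onnx", verbose=True)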
Hi,
I will try to see how this can work. To be sure, are you trying to save & load in Python? Or to save in Python and load in C++?
Only Python. I want to export the trained model to ONNX to be able to (for example) develop a REST service and avoid having to install pytorch-tabnet, using only the ONNX runtime.
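For illustration, inference with only the ONNX runtime would then look something like this (a sketch assuming the export succeeds; the file name and request payload are hypothetical):

import numpy as np
import onnxruntime as ort

# Load the exported model -- neither pytorch nor pytorch-tabnet is needed
sess = ort.InferenceSession("tabnet.onnx")
input_name = sess.get_inputs()[0].name

# Stand-in for the features parsed from a REST request body (hypothetical)
request_features = [0.0] * 54
X = np.asarray(request_features, dtype=np.float32).reshape(1, -1)
outputs = sess.run(None, {input_name: X})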
For that, you can use the pytorch save method, but I will come back with some tests using onnx.
@luigisaetta I think trying to ONNXify (whatever this is called) the entire TabNetClassifier class is doomed to fail - I'll be very surprised if you manage to export everything with ONNX or jit (@Hartorn I know you are full of surprises :) ).
I think you should focus on exporting the network only, which is accessible via the network attribute.
TabNetClassifier does not take tensors as input, but the network does, so it's weird.
@Optimox You are right, only the network should be exported, but I also have to check what the input format is (only one input, or several for the embeddings and so on).
I will see if we can set up some kind of optional dependency to have a custom exporter, or at least a notebook about this. I should manage to get something working, based on:
- https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
- http://onnx.ai/sklearn-onnx/auto_examples/plot_custom_model.html#sphx-glr-auto-examples-plot-custom-model-py
- http://onnx.ai/sklearn-onnx/auto_examples/plot_custom_parser_alternative.html#sphx-glr-auto-examples-plot-custom-parser-alternative-py
- http://onnx.ai/sklearn-onnx/auto_examples/plot_pipeline_lightgbm.html#sphx-glr-auto-examples-plot-pipeline-lightgbm-py
Will try to have a look this week.
Hi, as you can see from my comment above, I get an error when I apply torch.onnx.export to clf.network. So, as far as I understand, I am trying to export only the network. But since it wants an input... it doesn't accept a NumPy array.
@luigisaetta sorry, in the docstring example I wrote numpy array but it should be a torch.Tensor. Does it work with tensors?
No, it doesn't. I'll find my test and post the error here. Basically, I think it calls a method that exists on numpy arrays but not on Torch tensors.
@Optimox I have published an article on TowardsDataScience: https://towardsdatascience.com/pytorch-tabnet-integration-with-mlflow-cb14f3920cb0 There, I talk about pytorch-tabnet and its integration with MLflow.
Hello @luigisaetta,
Great article, very detailed! I'm happy to see that integration with MLflow is made so easy by the callbacks. @queraq worked on this and I'm sure neither of us had this specific usage in mind.
There is one part of the article where I think a bit of clarification would be welcome: the Encoder-Decoder part. In fact, TabNet models are only a sort of encoder (plus all the sequential attention part); there is no decoder at all. The only reason a decoder part exists is to enable self-supervised pre-training, which needs one. Since you do not mention pre-training in the article, I think you should not talk about an encoder-decoder model, or maybe you could add a paragraph about TabNetPretrainer.
Anyway great article, thanks for sharing with us and giving credits to the repo.
But... the article does not tell me if you managed to get ONNX format working?! :)
Cheers!
@Optimox regarding onnx, no, I didn't make any progress. The point where I got blocked is that TabNet doesn't seem to accept tensors as input. I think the code calls some methods that exist only for numpy arrays. Do you have any suggestions?
hmm actually I think I know.
If you have a look at this file https://github.com/dreamquark-ai/tabnet/blob/develop/pytorch_tabnet/tab_network.py where everything about the network happens, we are actually using numpy for some stupid reason (laziness and bad habits mainly).
I think it would be very easy to replace every np.* call in this code with its torch equivalent (see the sketch below).
This might have two positive effects:
- allow ONNX format
- probably speed up the code on GPU
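For illustration, the kind of change involved might look like this (a hypothetical before/after; the real call sites are in tab_network.py):

import torch

# Before: a numpy call inside the forward pass, which this thread
# suggests gets in the way of tracing/ONNX export:
#   scale = np.sqrt(0.5)
# After: the same constant computed with torch, so it stays in the traced graph
scale = torch.sqrt(torch.tensor(0.5))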
I don't have much time at the moment but I'll definitely change that. If you want to make those changes and see if it works for ONNX don't hesitate. You can also open a PR and I'll review it carefully.
Otherwise I'll do this as soon as I can or maybe @eduardocarvp will have a look before me?
I think we might have found your problem :)
Hi, has anyone made any progress on this yet? It would be really appreciated if you could share a bit about how the export would work. Right now I am stuck at exporting, and the error tells me this: "ONNX export failed: Couldn't export Python operator Entmax15Function"
Could anyone help?
Hi! I tried the save_torch_script approach above, but I get this error:
Could not export Python function call 'SparsemaxFunction'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to constants:
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/pytorch_tabnet/sparsemax.py(109): forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1118): _slow_forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1130): _call_impl
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py(640): forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1118): _slow_forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1130): _call_impl
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py(160): forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1118): _slow_forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1130): _call_impl
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py(471): forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1118): _slow_forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1130): _call_impl
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py(586): forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1118): _slow_forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1130): _call_impl
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/jit/_trace.py(967): trace_module
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/jit/_trace.py(750): trace
Is there any way to fix this?
Yes, you probably need to 'scriptify' sparsemax (and the entmax functions) so that they can be accepted for tracing.
I don't know how hard it would be; you can try adding @script on top of the definition of sparsemax and entmax and see if it works.
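For reference, this is what scripting a plain function looks like (a generic sketch; whether @torch.jit.script can be applied directly to SparsemaxFunction, which is a custom autograd.Function, is untested here):

import torch

@torch.jit.script
def halved_sum(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Compiled by TorchScript, so traced modules can call it without
    # falling back to a Python function call
    return (x + y) * 0.5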
Thanks for the reply. I'm trying to speed up inference, and torchscript is one way I was exploring. Is there any other more straightforward method you would suggest to speed it up before I try this with torchscript?
Oh guys, is it still not possible to export to ONNX right now?