Save TabNet in ONNX format

Open luigisaetta opened this issue 4 years ago • 20 comments

Feature request

Is it possible to save TabNet in ONNX format?

What is the expected behavior?

What is motivation or use case for adding/changing the behavior? ONNX is quickly becoming the de facto standard for saving models, partly because it lets you avoid importing the training packages when you package a model for inference.

How should this be implemented in your opinion?

Are you willing to work on this yourself? Well, for now I don't have a precise idea, but I'm willing to help if I get some suggestions on where to start.

The feature could probably be implemented as a notebook example, and therefore without any changes to the core implementation.

luigisaetta avatar Mar 28 '21 07:03 luigisaetta

I'm not familiar with ONNX, but it would be quite easy to save the network as traced script (from pytorch jit: https://pytorch.org/docs/stable/jit.html), which could be used for inference without the need of pytorch-tabnet but also without python itself (can be called in C++ only).

Training the model and then saving model.network in eval mode should work without problems. I think ONNX would work the same way. But these production-readiness requirements can be specific to each environment (some will have python inside docker with the library available for inference, some only C++, some will be interested in getting the explanations with the predictions while some will only care about predictions). So I feel we could only give some examples of how this would work, but since the network is accessible and is just a simple torch.nn.Module, I feel that it is a bit beyond the scope of the library.

Feel free to open a PR giving examples for either ONNX or jit and I'll be happy to review it (I'm not sure about adding onnx as a dependency in the repo, however). If you have questions I might be able to help with jit, but I guess it would look something like this:

import os

import torch


def save_torch_script(tab_model, X_infer_ex, saving_path, model_name):
    """
    Utility function to save a tabnet model as torch script

    Parameters
    ----------
    - tab_model : pytorch-tabnet model
        A trained model whose network will be saved
    - X_infer_ex : torch.Tensor of size (B, D), not a 2D np.array (B, D)
        Batch containing B examples with D features
    - saving_path : str
        Directory in which to save the file
    - model_name : str
        Name of the file to create, should not contain the extension
    Returns
    -------
    - traced_script_module : torch.jit.ScriptModule
        Traced model that has been saved
    """
    # Switch to eval mode so the traced graph matches inference behavior.
    tab_model.network.eval()
    # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
    traced_script_module = torch.jit.trace(tab_model.network, X_infer_ex)
    # Join directory and file name so a missing trailing separator is harmless.
    traced_script_module.save(os.path.join(saving_path, model_name + ".pt"))
    return traced_script_module

I think that's it! (this will only trace the forward, without explanation, you'll need to create a wrapper with a custom forward function to get both preds and explanations)
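A wrapper of the kind mentioned above could look roughly like the following sketch. It assumes the network's forward returns a (predictions, auxiliary loss) pair and that it exposes a forward_masks method returning an (aggregated mask, per-step masks) pair. ToyNetwork is a hypothetical stand-in so that the snippet runs on its own, and PredsAndExplanations and trace_save_and_reload are illustrative names, not part of pytorch-tabnet.

```python
import os
import tempfile

import torch


class ToyNetwork(torch.nn.Module):
    """Hypothetical stand-in for tab_model.network, used only for this demo."""

    def __init__(self, input_dim: int = 4, output_dim: int = 2):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        # Assumed output layout: (predictions, auxiliary loss).
        return self.linear(x), x.abs().mean()

    def forward_masks(self, x):
        # Assumed output layout: (aggregated mask, dict of per-step masks).
        mask = torch.softmax(x, dim=-1)
        return mask, {0: mask}


class PredsAndExplanations(torch.nn.Module):
    """Custom forward that returns only tensors, so it can be traced."""

    def __init__(self, network: torch.nn.Module):
        super().__init__()
        self.network = network

    def forward(self, x):
        preds, _ = self.network(x)
        explanations, _ = self.network.forward_masks(x)
        return preds, explanations


def trace_save_and_reload(network, x_example, path):
    """Trace the wrapper, save it, and load it back with torch.jit only."""
    wrapper = PredsAndExplanations(network).eval()
    traced = torch.jit.trace(wrapper, x_example)
    traced.save(path)
    return torch.jit.load(path)


torch.manual_seed(0)
network = ToyNetwork()
x = torch.randn(3, 4)
with tempfile.TemporaryDirectory() as tmp_dir:
    reloaded = trace_save_and_reload(network, x, os.path.join(tmp_dir, "toy.pt"))
preds, explanations = reloaded(x)
```

The reloaded module only needs torch at inference time. Whether the real network traces cleanly is a separate question that this toy example does not answer.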

Optimox avatar Mar 28 '21 09:03 Optimox

Hi Optimox, thanks for your quick answer.

My main use case is to be able to pack the trained model in a REST service for predictions. In Python.

As far as I understand (I'm not a great expert of PyTorch) a TabNetClassifier is a torch.nn.Module, so as explained here:

https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html

we should be able to export through tracing, using torch.onnx.export

as soon as I have time, I'll follow, need to see what is the meaning of the params

I agree with you that adding onnx to TabNet is not what should be done. I was thinking of adding an example notebook and best practices.

luigisaetta avatar Mar 28 '21 13:03 luigisaetta

It seems more difficult than I expected. When I call:

torch.onnx.export(clf.network, dummy_input, "tabnet1", verbose=True)

I get the following error:

RuntimeError: Only tuples, lists, and Variables supported as JIT inputs/outputs. Dictionaries and strings are also accepted but their usage is not recommended. But got unsupported type numpy.ndarray

It seems:

  • cannot export if the input is NumPy array
  • but TabNetClassifier doesn't accept Torch tensor as input

What kind of data can I pass as input? Only NumPy arrays? That is strange, since it is PyTorch under the hood. It should be easy to accept a Tensor.
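For reference, the error above is about the dummy input type, so one thing to try is converting the NumPy batch to a float32 tensor before calling torch.onnx.export on the network. The sketch below is only an assumption-laden illustration: to_dummy_input and export_network_to_onnx are hypothetical helper names, the input name "features" is arbitrary, and it only addresses the ndarray error. Later comments in this thread report that the export can still fail on the custom Sparsemax/Entmax functions.

```python
import numpy as np
import torch


def to_dummy_input(X: np.ndarray) -> torch.Tensor:
    """Convert a 2D NumPy batch (B, D) into a float32 torch.Tensor."""
    return torch.from_numpy(np.ascontiguousarray(X, dtype=np.float32))


def export_network_to_onnx(network: torch.nn.Module, X: np.ndarray, path: str) -> None:
    """Export a torch.nn.Module to ONNX using a tensor dummy input."""
    network.eval()
    torch.onnx.export(
        network,
        to_dummy_input(X),
        path,
        input_names=["features"],
        # Mark the batch dimension as dynamic so any batch size is accepted.
        dynamic_axes={"features": {0: "batch"}},
    )


# NumPy creates float64 by default; the helper casts to float32.
X = np.random.rand(8, 5)
dummy_input = to_dummy_input(X)

# Hypothetical usage, where clf and X_valid come from your own training code:
# export_network_to_onnx(clf.network, X_valid, "tabnet.onnx")
```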

luigisaetta avatar Mar 29 '21 07:03 luigisaetta

Hi,

I will try to see how this can work. To be sure, are you trying to save and load in Python, or are you trying to save in Python and load in C++?

Hartorn avatar Mar 29 '21 07:03 Hartorn

Only Python. I want to export the trained model to ONNX so that I can (for example) develop a REST service using only ONNX Runtime, without having to install pytorch-tabnet.

luigisaetta avatar Mar 29 '21 07:03 luigisaetta

For that, you can use pytorch save method, but I will come back with some tests using onnx

Hartorn avatar Mar 29 '21 09:03 Hartorn

@luigisaetta I think trying to ONNXify (whatever this is called) the entire class TabNetClassifier is doomed to fail - I'll be very surprised if you manage to export everything with ONNX or jit (@Hartorn I know you are full of surprises :) ). I think you should focus on exporting the network only, which is accessible through the network attribute.

TabNetClassifier does not take tensors as input, but the network does, so it's weird.

Optimox avatar Mar 29 '21 09:03 Optimox

@Optimox You are right, only the network should be exported, but I have to check also how the input format is (only one input or several for the embeddings and so on)

I will see if we can set up some kind of optional deps to have a custom exporter, or at least to have a notebook concerning this. I should manage to get something working, based on:

https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html#
http://onnx.ai/sklearn-onnx/auto_examples/plot_custom_model.html#sphx-glr-auto-examples-plot-custom-model-py
http://onnx.ai/sklearn-onnx/auto_examples/plot_custom_parser_alternative.html#sphx-glr-auto-examples-plot-custom-parser-alternative-py
http://onnx.ai/sklearn-onnx/auto_examples/plot_pipeline_lightgbm.html#sphx-glr-auto-examples-plot-pipeline-lightgbm-py

Will try to have a look this week.

Hartorn avatar Mar 29 '21 09:03 Hartorn

Hi, as you can see from my comment above, I get an error when I apply torch.onnx.export to clf.network. So, as far as I understand, I am already trying to export only the network. But the export needs an example input, and it doesn't accept a NumPy array.

luigisaetta avatar Mar 29 '21 11:03 luigisaetta

@luigisaetta sorry, in the docstring example I wrote numpy array but it should be a torch.Tensor. Does it work with tensors?

Optimox avatar Apr 02 '21 07:04 Optimox

No, it doesn't. I'll find my test and post here the error. Basically, I think it calls a method that exists in numpy array but not in Torch tensor.

luigisaetta avatar Apr 02 '21 11:04 luigisaetta

@Optimox I have published an article on TowardsDataScience, https://towardsdatascience.com/pytorch-tabnet-integration-with-mlflow-cb14f3920cb0 There, I'm talking about pytorch-tabnet and integration with MLflow.

luigisaetta avatar Apr 16 '21 13:04 luigisaetta

Hello @luigisaetta,

Great article, very detailed! I'm happy to see that integration with MLflow is made so easy by the callbacks. @queraq worked on this and I'm sure neither of us had this specific usage in mind.

There is one part of the article where I think a bit of clarification would be welcomed: the Encoder-Decoder part. In fact TabNet models are only sort of encoders (plus all the sequential attention part), there is no decoder at all. The only reason there exists a decoder part is to enable self-supervised pre-training, which needs a decoder. Since you do not mention pre-training in the article I think you should not talk about a decoder-encoder model, or maybe you could add a paragraph about TabNetPretrainer.

Anyway great article, thanks for sharing with us and giving credits to the repo.

But... the article does not tell me if you managed to get ONNX format working?! :)

Cheers!

Optimox avatar Apr 16 '21 14:04 Optimox

@Optimox regarding onnx, no, I didn't make any progress. The point where I got blocked is that TabNet doesn't seem to accept tensors as input. I think the code calls some methods that exist only for numpy arrays. Do you have any suggestions?

luigisaetta avatar Apr 16 '21 14:04 luigisaetta

hmm actually I think I know.

If you have a look at this file https://github.com/dreamquark-ai/tabnet/blob/develop/pytorch_tabnet/tab_network.py where everything about the network happens, we are actually using numpy for some stupid reason (laziness and bad habits mainly).

I think it would be very easy to replace all the np.any_function in this code by the torch equivalent. This might have two positive effects:

  • allow ONNX format
  • probably speed up the code with GPU

I don't have much time at the moment but I'll definitely change that. If you want to make those changes and see if it works for ONNX don't hesitate. You can also open a PR and I'll review it carefully.

Otherwise I'll do this as soon as I can or maybe @eduardocarvp will have a look before me?

I think we might have found your problem :)

Optimox avatar Apr 16 '21 15:04 Optimox

Hi, has anyone made any progress on this yet? It would be really appreciated if you could share a bit about how the export would work. Right now I am stuck at exporting and the error tells me this: "ONNX export failed: Couldn't export Python operator Entmax15Function"

Could anyone help?

rxbh2019 avatar Jul 25 '21 04:07 rxbh2019

Hi! I tried the save_torch_script function suggested above, but I get this error:

Could not export Python function call 'SparsemaxFunction'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to constants:
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/pytorch_tabnet/sparsemax.py(109): forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1118): _slow_forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1130): _call_impl
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py(640): forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1118): _slow_forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1130): _call_impl
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py(160): forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1118): _slow_forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1130): _call_impl
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py(471): forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1118): _slow_forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1130): _call_impl
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/pytorch_tabnet/tab_network.py(586): forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1118): _slow_forward
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/nn/modules/module.py(1130): _call_impl
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/jit/_trace.py(967): trace_module
/local_disk0/.ephemeral_nfs/envs/pythonEnv-0c30c1b9-5719-40d6-a3a9-5d19c4c686f8/lib/python3.8/site-packages/torch/jit/_trace.py(750): trace
(23):
/databricks/python/lib/python3.8/site-packages/IPython/core/interactiveshell.py(3437): run_code
/databricks/python/lib/python3.8/site-packages/IPython/core/interactiveshell.py(3357): run_ast_nodes
/databricks/python/lib/python3.8/site-packages/IPython/core/interactiveshell.py(3165): run_cell_async
/databricks/python/lib/python3.8/site-packages/IPython/core/async_helpers.py(68): _pseudo_sync_runner
/databricks/python/lib/python3.8/site-packages/IPython/core/interactiveshell.py(2940): _run_cell
/databricks/python/lib/python3.8/site-packages/IPython/core/interactiveshell.py(2894): run_cell
/databricks/python_shell/scripts/PythonShellImpl.py(757): run_cell
/databricks/python_shell/scripts/PythonShellImpl.py(269): run
/databricks/python_shell/scripts/PythonShellImpl.py(1234): launch_process
/databricks/python_shell/scripts/PythonShell.py(29):

Is there any way to fix this?

mythicaa avatar Sep 30 '22 09:09 mythicaa

Yes you probably need to 'scriptify' sparsemax (and entmax functions) so that they can be accepted for tracing.

I don't know how hard it would be, you can try adding @script on top of the definition of sparsemax and entmax and see if it works.
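If decorating the existing functions does not work, another option might be to swap in a forward pass written only with regular tensor operations, which tracing can record. The following is a rough sketch of that idea for sparsemax over the last dimension. It is not the implementation shipped in pytorch_tabnet/sparsemax.py, sparsemax_traceable is a hypothetical name, and it relies on autograd's default gradients instead of a hand-written backward, which should not matter for inference.

```python
import torch


def sparsemax_traceable(z: torch.Tensor) -> torch.Tensor:
    """Sparsemax over the last dimension using only built-in tensor ops."""
    # Shift by the max for numerical stability (the result is unchanged).
    z = z - z.max(dim=-1, keepdim=True).values
    z_sorted, _ = torch.sort(z, dim=-1, descending=True)
    k = torch.arange(1, z.shape[-1] + 1, dtype=z.dtype, device=z.device)
    cumsum = z_sorted.cumsum(dim=-1)
    # Sorted entries that belong to the support of the projection.
    support = (1 + k * z_sorted) > cumsum
    k_z = support.sum(dim=-1, keepdim=True)
    # Threshold tau such that the clamped outputs sum to one.
    tau = (cumsum.gather(-1, k_z - 1) - 1) / k_z.to(z.dtype)
    return torch.clamp(z - tau, min=0)


# Trace the plain function and compare it with eager execution.
example = torch.randn(4, 3)
traced_sparsemax = torch.jit.trace(sparsemax_traceable, example)

batch = torch.tensor([[2.0, 0.0, 0.0], [0.5, 0.3, 0.2]])
eager_out = sparsemax_traceable(batch)
traced_out = traced_sparsemax(batch)
```

Each output row is non-negative and sums to one, and rows with one dominant logit come out exactly sparse. Wiring this into the network in place of the existing function, and doing the same for entmax, would still need to be done and checked separately.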

Optimox avatar Sep 30 '22 09:09 Optimox

Thanks for the reply. I'm trying to speed up inference, and TorchScript is one approach I was trying. Is there a more straightforward way to speed it up that you would suggest before I try this for TorchScript?

mythicaa avatar Sep 30 '22 09:09 mythicaa

Hi all, is it still not possible to export to ONNX right now?

duanckham avatar Jun 20 '23 17:06 duanckham