Can I use a PyTorch-style training loop instead of the scikit-learn-compatible API of this library?

Status: Open. kaiwang0112006 opened this issue 1 year ago • 3 comments

I want to research TabNet in a federated learning setting, which means I need to extract the model weights and load them back in during each epoch of training. This would be easier with a PyTorch-style, epoch-by-epoch training loop than with the library's scikit-learn-compatible training API.
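To make the per-epoch weight exchange concrete, here is a minimal sketch in plain PyTorch, assuming a standard FedAvg setup (each client's `state_dict` is averaged, weighted by its number of training samples). `get_weights`, `set_weights` and `fed_avg` are hypothetical helper names, not part of pytorch_tabnet; they work on any `torch.nn.Module`, which includes the TabNet network module.

```python
from typing import Dict, List

import torch
from torch import nn

StateDict = Dict[str, torch.Tensor]


def get_weights(model: nn.Module) -> StateDict:
    """Copy all parameters and buffers out of a model (e.g. to send to a server)."""
    return {name: t.detach().clone() for name, t in model.state_dict().items()}


def set_weights(model: nn.Module, weights: StateDict) -> None:
    """Load weights (e.g. the server's aggregate) back into a local model."""
    model.load_state_dict(weights)


def fed_avg(client_weights: List[StateDict], client_sizes: List[int]) -> StateDict:
    """FedAvg: average every floating-point tensor, weighted by client dataset size."""
    total = float(sum(client_sizes))
    averaged: StateDict = {}
    for name, ref in client_weights[0].items():
        if not torch.is_floating_point(ref):
            # Integer buffers (e.g. BatchNorm's num_batches_tracked) are not
            # meaningfully averaged; keep the first client's value.
            averaged[name] = ref.clone()
            continue
        averaged[name] = sum(
            w[name] * (n / total) for w, n in zip(client_weights, client_sizes)
        )
    return averaged
```

In one federated round, each client would call `get_weights` after its local epoch, the server would combine the results with `fed_avg`, and each client would call `set_weights` before the next epoch. TabNet uses batch normalization, so its running statistics are buffers in the `state_dict` and get averaged here too; whether to share them is a design choice (FedBN-style methods keep them local).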

kaiwang0112006 · Aug 22 '24 01:08

Hello, I am not sure I understand your request. But if you simply want to use the TabNet network as a PyTorch module and insert it into your own pipeline, you can use the modules from here: https://github.com/dreamquark-ai/tabnet/blob/2c0c4ebd2bb1cb639ea94ab4b11823bc49265588/pytorch_tabnet/tab_network.py#L508
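As a sketch of what using the network as a plain PyTorch module can look like, here is an ordinary epoch loop. It assumes the network's `forward` returns an `(output, M_loss)` pair, which is how I understand the TabNet class in the linked `tab_network.py` to behave. `StandInTabNet` is a hypothetical placeholder so the snippet runs without the library; you would swap in the real class (check its constructor arguments in the linked file). The `lambda_sparse` weighting of the sparsity term is also an assumption.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class StandInTabNet(nn.Module):
    """Hypothetical stand-in for the TabNet module so this sketch runs without
    pytorch_tabnet. As assumed for TabNet, forward returns (output, M_loss)."""

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor):
        # A real TabNet would return its attention-sparsity loss as the second item.
        return self.linear(x), torch.zeros((), device=x.device)


def train_one_epoch(model, loader, optimizer, loss_fn, lambda_sparse=1e-3):
    """Run one epoch of a plain PyTorch loop and return the mean batch loss."""
    model.train()  # matters for TabNet, which contains batch-norm layers
    total, n_batches = 0.0, 0
    for X, y in loader:
        optimizer.zero_grad()
        output, m_loss = model(X)
        # Assumption: subtract the sparsity term, as the library's own fit loop
        # appears to do (loss - lambda_sparse * M_loss).
        loss = loss_fn(output, y) - lambda_sparse * m_loss
        loss.backward()
        optimizer.step()
        total += loss.item()
        n_batches += 1
    return total / max(n_batches, 1)


# Usage on toy data: the label depends only on the first feature.
torch.manual_seed(0)
X = torch.randn(64, 4)
y = (X[:, 0] > 0).long()
loader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)
model = StandInTabNet(input_dim=4, output_dim=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)

epoch_losses = []
for epoch in range(20):
    epoch_losses.append(train_one_epoch(model, loader, optimizer, nn.CrossEntropyLoss()))
    # Federated hook (hypothetical): send model.state_dict() to the server here,
    # then call model.load_state_dict(aggregated_weights) before the next epoch.
```

Because you own the loop, the point between epochs is where weights can be exported and re-imported, which is exactly what the scikit-style `fit` call hides.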

Optimox · Aug 22 '24 09:08

That's great! What should I pass as the "group_attention_matrix" parameter?

kaiwang0112006 · Aug 23 '24 03:08

This is an advanced feature. You can leave it as None; otherwise, you'll need to dig into the code a bit to use it. It's just a matrix of weights that specifies how attention can work across different features.
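For illustration only, here is a sketch of building such a matrix, based on my reading of how the library's own helper (`create_group_matrix` in `pytorch_tabnet/utils.py`, if your version has it, in which case prefer it) lays things out: one row per feature group, one column per input feature, each group's weight spread evenly over its features so every row sums to 1, and any ungrouped feature in its own singleton row. `build_group_matrix` is a hypothetical name; the exact layout the network expects is an assumption, so check it against the linked module.

```python
from typing import List

import torch


def build_group_matrix(groups: List[List[int]], input_dim: int) -> torch.Tensor:
    """Hypothetical helper: return an (n_groups, input_dim) attention-group matrix.

    Row i gives weight 1/len(group) to each feature of group i, so every row
    sums to 1. Features not listed in any group get their own singleton row.
    """
    seen = [f for g in groups for f in g]
    seen_set = set(seen)
    if any(len(g) == 0 for g in groups):
        raise ValueError("groups must be non-empty")
    if len(seen) != len(seen_set):
        raise ValueError("a feature may belong to at most one group")
    if any(f < 0 or f >= input_dim for f in seen):
        raise ValueError("feature index out of range")

    singletons = [f for f in range(input_dim) if f not in seen_set]
    matrix = torch.zeros((len(groups) + len(singletons), input_dim))
    for row, group in enumerate(groups):
        matrix[row, group] = 1.0 / len(group)  # spread attention evenly in the group
    for row, feat in enumerate(singletons, start=len(groups)):
        matrix[row, feat] = 1.0  # ungrouped feature attends on its own
    return matrix
```

With no groups at all, this returns the identity matrix (every feature is its own group), which should correspond to the default behaviour of not grouping features.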

Optimox · Aug 23 '24 08:08