
Trouble with learning when using device = "auto" configuration

Open ruddnr opened this issue 1 year ago • 1 comment

Hi. I'm trying to fit a model, but when I use the default "auto" device configuration, the loss does not decrease after 10 epochs. I'm on Windows and my CUDA runtime version is 11.3.0.

tabnet_fit(Sale_Price ~ ., data = ames,
           config = tabnet_config(learn_rate = 0.02, device = "auto", verbose = TRUE, epoch = 30))
[Epoch 001] Loss: 39066898432.000000
[Epoch 002] Loss: 39066730496.000000
[Epoch 003] Loss: 39066587136.000000
[Epoch 004] Loss: 39066431488.000000
[Epoch 005] Loss: 39066255360.000000
[Epoch 006] Loss: 39066083328.000000
[Epoch 007] Loss: 39065911296.000000
[Epoch 008] Loss: 39065739264.000000
[Epoch 009] Loss: 39065571328.000000
[Epoch 010] Loss: 39065399296.000000
[Epoch 011] Loss: 39065178112.000000
[Epoch 012] Loss: 39065178112.000000
[Epoch 013] Loss: 39065178112.000000
[Epoch 014] Loss: 39065178112.000000
[Epoch 015] Loss: 39065178112.000000
[Epoch 016] Loss: 39065178112.000000
[Epoch 017] Loss: 39065178112.000000
[Epoch 018] Loss: 39065178112.000000
[Epoch 019] Loss: 39065178112.000000
[Epoch 020] Loss: 39065178112.000000
[Epoch 021] Loss: 39065178112.000000
[Epoch 022] Loss: 39065178112.000000
[Epoch 023] Loss: 39065178112.000000
[Epoch 024] Loss: 39065178112.000000
[Epoch 025] Loss: 39065178112.000000
[Epoch 026] Loss: 39065178112.000000
[Epoch 027] Loss: 39065178112.000000
[Epoch 028] Loss: 39065178112.000000
[Epoch 029] Loss: 39065178112.000000
[Epoch 030] Loss: 39065178112.000000
An `nn_module` containing 10,742 parameters.

── Modules ──────────────────────────────────────────────────────────────────────────────────────────────
• embedder: <embedding_generator> #283 parameters
• embedder_na: <na_embedding_generator> #0 parameters
• tabnet: <tabnet_no_embedding> #10,458 parameters

── Parameters ───────────────────────────────────────────────────────────────────────────────────────────
• .check: Float [1:1]

However, if I manually set the device to "cpu" or "gpu", it works. I don't need the GPU since I'm working with a small dataset, but I don't know how to set the device when I use the fit_resamples() function from tidymodels.

tabnet_fit(Sale_Price ~ ., data = ames,
           config = tabnet_config(learn_rate = 0.02, device = "gpu", verbose = TRUE, epoch = 30))
[Epoch 001] Loss: 39066984448.000000
[Epoch 002] Loss: 39066828800.000000
[Epoch 003] Loss: 39066701824.000000
[Epoch 004] Loss: 39066599424.000000
[Epoch 005] Loss: 39066472448.000000
[Epoch 006] Loss: 39066333184.000000
[Epoch 007] Loss: 39066202112.000000
[Epoch 008] Loss: 39066054656.000000
[Epoch 009] Loss: 39065935872.000000
[Epoch 010] Loss: 39065735168.000000
[Epoch 011] Loss: 39065550848.000000
[Epoch 012] Loss: 39065358336.000000
[Epoch 013] Loss: 39065174016.000000
[Epoch 014] Loss: 39064932352.000000
[Epoch 015] Loss: 39064719360.000000
[Epoch 016] Loss: 39064481792.000000
[Epoch 017] Loss: 39064223744.000000
[Epoch 018] Loss: 39063965696.000000
[Epoch 019] Loss: 39063650304.000000
[Epoch 020] Loss: 39063351296.000000
[Epoch 021] Loss: 39063044096.000000
[Epoch 022] Loss: 39062691840.000000
[Epoch 023] Loss: 39062327296.000000
[Epoch 024] Loss: 39061958656.000000
[Epoch 025] Loss: 39061561344.000000
[Epoch 026] Loss: 39061139456.000000
[Epoch 027] Loss: 39060705280.000000
[Epoch 028] Loss: 39060262912.000000
[Epoch 029] Loss: 39059787776.000000
[Epoch 030] Loss: 39059300352.000000
An `nn_module` containing 10,742 parameters.

── Modules ──────────────────────────────────────────────────────────────────────────────────────────────
• embedder: <embedding_generator> #283 parameters
• embedder_na: <na_embedding_generator> #0 parameters
• tabnet: <tabnet_no_embedding> #10,458 parameters

── Parameters ───────────────────────────────────────────────────────────────────────────────────────────
• .check: Float [1:1]

ruddnr · Aug 17 '22 04:08

Hello @ruddnr,

The behavior is: device = "auto" resolves to device = "cuda" as soon as CUDA is available. There is no magic here, so that alone should not change the overall behavior. You may have hit a local minimum in one of the two runs, perhaps because of a difference in random initialization. Have you fixed all the seeds? Another clue would be to check whether a validation set suffers from the same problem (using valid_split = 0.2 in the configuration).
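
A minimal sketch of those two suggestions, assuming ames comes from the modeldata package and using the documented tabnet_config() arguments (epochs, valid_split); set.seed() fixes the R RNG used for the validation split and torch_manual_seed() fixes the torch RNG used for weight initialization:

library(torch)
library(tabnet)
data("ames", package = "modeldata")   # assumption: same ames data as in the report

torch::cuda_is_available()   # what device = "auto" resolves to: TRUE -> "cuda"

set.seed(42)                 # fix the R RNG (data shuffling, validation split)
torch_manual_seed(42)        # fix the torch RNG (weight initialization)

fit <- tabnet_fit(
  Sale_Price ~ ., data = ames,
  config = tabnet_config(
    learn_rate = 0.02,
    device = "auto",
    verbose = TRUE,
    epochs = 30,
    valid_split = 0.2        # also report a validation loss alongside the training loss
  )
)

With identical seeds, a remaining difference between the "auto" and "cpu"/"gpu" runs would point at the device resolution rather than at random initialization.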

cregouby · Aug 17 '22 08:08
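
On ruddnr's remaining question about fit_resamples(): the device is not a resampling or tuning parameter, so one option is to set it on the parsnip tabnet() specification. The sketch below assumes that extra set_engine("torch", ...) arguments (such as device and verbose) are forwarded to the underlying tabnet_fit()/tabnet_config() call; the rest is standard tidymodels resampling:

library(tidymodels)
library(tabnet)
data("ames", package = "modeldata")

# assumption: engine arguments are passed through to tabnet_config()
mod <- tabnet(epochs = 30, learn_rate = 0.02) %>%
  set_engine("torch", device = "cpu", verbose = TRUE) %>%
  set_mode("regression")

wf <- workflow() %>%
  add_model(mod) %>%
  add_formula(Sale_Price ~ .)

folds <- vfold_cv(ames, v = 5)

res <- fit_resamples(wf, resamples = folds)
collect_metrics(res)

Every resample then trains on the device given in set_engine(), so the behavior reported above can be reproduced (or avoided) inside the tidymodels workflow as well.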