NameError: name 'GatedTabTransformer' is not defined

Open zmce2018 opened this issue 2 years ago • 2 comments

Hi Team,

I have been running experiments with different models and found that I am not able to use the GatedTabTransformer model.

My environment is:

os             : Windows-10-10.0.19043-SP0
python         : 3.8.13
tsai           : 0.3.1
fastai         : 2.6.3
fastcore       : 1.4.3
optuna         : 2.10.0
torch          : 1.11.0
device         : 1 gpu (['NVIDIA GeForce GTX 1080 Ti'])
cpu cores      : 12
RAM            : 31.95 GB
GPU memory     : [11.0] GB

Code I was trying to run:

learn = TSRegressor(X, y, splits=splits, bs=[64], batch_tfms=batch_tfms,
                    arch="GatedTabTransformer", loss_func=mae,
                    metrics=mae, device='cuda', cbs=ShowGraph())

Error is here:

  File F:\CNNpred\tsai_TST.py:30 in <module>
    arch=GatedTabTransformer, loss_func = mae,

NameError: name 'GatedTabTransformer' is not defined

The other models I have tested work fine.

zmce2018, Jun 06 '22 06:06

From the error message, I infer that the code you actually ran has arch=GatedTabTransformer (a bare name) rather than arch="GatedTabTransformer" (a string) as in the snippet you posted. The bare name fails because GatedTabTransformer has not been imported. You can import it explicitly:

from tsai.models.GatedTabTransformer import GatedTabTransformer

However, I can confirm that when you do from tsai.all import *, TabTransformer is imported but GatedTabTransformer is not, and that should be fixed.
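
To illustrate why the quoted and unquoted forms behave differently, here is a minimal plain-Python sketch. It does not import tsai at all; the two helper functions are hypothetical and only exist to show that a string is passed along as data, while a bare name is looked up immediately and raises NameError if nothing has defined it.

```python
def call_with_string():
    # A string literal is just data; Python does not try to resolve it.
    return dict(arch="GatedTabTransformer")


def call_with_bare_name():
    # A bare identifier is looked up when this line runs. Nothing in this
    # script imports or defines GatedTabTransformer, so NameError is raised.
    return dict(arch=GatedTabTransformer)  # noqa: F821


kwargs = call_with_string()

try:
    call_with_bare_name()
    error_message = None
except NameError as exc:
    error_message = str(exc)

print(kwargs["arch"])
print(error_message)
```

With the explicit import above in place, the bare name would resolve and the NameError would go away.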

radi-cho, Jul 08 '22 09:07

Hi, first of all, sorry for the late reply. There is actually a reason why this model cannot be used like the others in TSRegressor: it is not a time series model per se. It is a tabular model that needs to be instantiated in a different way, so it has to be created outside the learner and then passed as a model (instead of an architecture).
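
The difference between passing an architecture and passing a model can be sketched with a toy example. Everything below is hypothetical (ToyTimeSeriesNet, ToyTabularNet, make_learner) and is not tsai's implementation. It assumes, for illustration only, that a learner builds an architecture from the input and output sizes, which is not enough information to build a model that expects column metadata.

```python
class ToyTimeSeriesNet:
    """Hypothetical stand-in: can be built from input/output sizes alone."""

    def __init__(self, c_in, c_out):
        self.c_in, self.c_out = c_in, c_out


class ToyTabularNet:
    """Hypothetical stand-in: needs column metadata to be built."""

    def __init__(self, classes, cont_names, c_out):
        self.classes, self.cont_names, self.c_out = classes, cont_names, c_out


def make_learner(c_in, c_out, arch=None, model=None):
    """Toy learner factory (an assumption, not a real library function)."""
    if model is None:
        # The learner only knows c_in and c_out, so it can only build
        # architectures whose constructor accepts exactly those arguments.
        model = arch(c_in, c_out)
    return {"model": model}


# Passing an architecture works when its constructor matches.
ts_learner = make_learner(3, 1, arch=ToyTimeSeriesNet)

# Passing the tabular class as an architecture fails: its constructor
# expects three arguments the learner does not supply.
try:
    make_learner(3, 1, arch=ToyTabularNet)
    arch_error = None
except TypeError as exc:
    arch_error = str(exc)

# Building the tabular model first and passing the instance works.
tab_model = ToyTabularNet(classes={"colour": ["red", "blue"]},
                          cont_names=["age"], c_out=1)
tab_learner = make_learner(3, 1, model=tab_model)

print(type(ts_learner["model"]).__name__)
print(arch_error is not None)
print(tab_learner["model"] is tab_model)
```

The same idea applies here: build the tabular model yourself with the arguments it needs, then hand the finished instance to the learner.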

> However, I can confirm that when you do from tsai.all import *, TabTransformer is imported but GatedTabTransformer is not, and that should be fixed.

This was indeed an issue. I've fixed it now.

oguiza, Aug 18 '22 10:08