tabnet
Validation loss for pretraining
Hi there, big fan of the project. I have a question that some of you may have come across. I have been able to lower the pretraining loss to below 1. Job done :) But the validation loss is annoyingly stuck just above 1, with values like 1.01. I've tried increasing the capacity of the model, adding weight decay, etc., to no avail: it just stays at around 1.01.
Any thoughts or suggestions?
Pretraining can be tricky, see my here.
Loss is also very dependent on the dataset distribution (the deviation between samples inside the dataset). Unfortunately I don't have many tips to share. Please share if you do manage to lower your validation loss.
Could this be due to how the validation and training losses are calculated? The validation loss is computed on the full validation data via NumPy on the CPU (rightly so, given the GPU memory issues; I read that issue), while the training loss is computed on a per-batch basis? Also, 1.01 isn't bad: I think it is within a reasonable margin of error for an acceptable validation pretraining loss, isn't it?
Above 1.0 is still a bit disappointing, even though it seems your pretraining is not doing worse than random, which is still something... The only way to check whether the pretraining is useful is to compute CV scores on your supervised dataset with and without pretraining and see whether you get an actual improvement.
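One way to see why 1.0 is the reference point: the pretraining loss divides each squared reconstruction error by the per-feature variance of the batch, so a "model" that simply predicts each feature's mean scores just under 1. Below is a minimal NumPy sketch of that, assuming a hypothetical helper `normalized_recon_loss` that mirrors the PyTorch snippet quoted in this thread (it is not the library's own function):

```python
import numpy as np

def normalized_recon_loss(y_pred, x, mask, eps=1e-9):
    """Hypothetical NumPy mirror of the variance-normalised pretraining loss.

    All arrays have shape (batch_size, n_features); mask is 1 where a feature
    must be reconstructed and 0 elsewhere.
    """
    # Same zero fallbacks as the PyTorch snippet quoted in this thread.
    means = x.mean(axis=0)
    means[means == 0] = 1
    var = np.std(x, axis=0, ddof=1) ** 2  # ddof=1 matches torch.std's default
    var[var == 0] = means[var == 0]
    # Squared masked errors divided by each feature's variance, summed per row.
    per_row = ((y_pred - x) * mask) ** 2 @ (1.0 / var)
    # Average over reconstructed features, then over the batch.
    return (per_row / (mask.sum(axis=1) + eps)).mean()

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=[1.0, 10.0, 0.1], size=(1024, 3))
mask = np.ones_like(x)

# Baseline "model": predict each feature's mean for every row.
mean_pred = np.broadcast_to(x.mean(axis=0), x.shape)
print(normalized_recon_loss(mean_pred, x, mask))  # about (n - 1) / n = 1023 / 1024
```

Under these assumptions, a validation loss stuck at about 1.01 means the reconstructions are, on average, roughly as good as predicting each feature's mean.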
Also, perhaps a naive question: would you recommend focusing on n_a/n_d, on n_steps, or on increasing the number of independent/shared units? I appreciate the architecture, but I would love to hear from folks with more experience with it about what may help narrow the gap between training and validation. Sorry to bother you.
n_d, n_a and n_steps are definitely the most important parameters of the architecture. But learning rate (and batch size) also plays a great role on how the training goes, I would personally start by playing with batch size and obfuscation factor for the pretraining.
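A small sweep following that advice could be organised as below. This sketch only builds the candidate settings: the key names are meant to mirror pytorch_tabnet's `TabNetPretrainer` constructor (`n_d`, `n_a`, `n_steps`, `optimizer_params`) and its `fit` arguments (`batch_size`, `virtual_batch_size`, and `pretraining_ratio`, which plays the role of the obfuscation factor), but `build_grid` is hypothetical and the value ranges are illustrative assumptions, not recommendations.

```python
from itertools import product

def build_grid():
    """Hypothetical sweep over the knobs suggested above; values are illustrative."""
    batch_sizes = [256, 1024, 4096]
    pretraining_ratios = [0.2, 0.5, 0.8]  # expected share of features obfuscated
    learning_rates = [1e-3, 2e-2]
    grid = []
    for bs, ratio, lr in product(batch_sizes, pretraining_ratios, learning_rates):
        grid.append({
            # Constructor arguments; the architecture is held fixed while sweeping.
            "init": {"n_d": 8, "n_a": 8, "n_steps": 3, "optimizer_params": {"lr": lr}},
            # fit() arguments.
            "fit": {"batch_size": bs, "virtual_batch_size": min(128, bs),
                    "pretraining_ratio": ratio},
        })
    return grid

grid = build_grid()
print(len(grid))  # 18
```

Each entry could then be unpacked as `TabNetPretrainer(**cfg["init"])` followed by `.fit(X_train, eval_set=[X_valid], **cfg["fit"])`, recording the validation loss per run; check that call pattern against the installed version before relying on it.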
Many thanks for all your advice :) One potentially silly question: I of course use the training set to get the training loss, but I had some doubts about how the validation loss is calculated. To confirm or rule out my suspicions, I also put the training set in the validation metrics, and strangely the values are different, even though it's the same dataset. Is this because the training loss is computed batch-wise while the validation/test loss is computed on the full dataset, without any batching? I'll take a closer look at the internals of your code too; surely I have missed something.
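Batching alone can make the two numbers disagree on identical data, because a batch-normalised loss recomputes the variance from whichever rows are in the current call. The NumPy sketch below assumes every feature is reconstructed (no mask, no zero fallbacks) and uses a hypothetical `recon_loss`; it scores one fixed set of predictions twice, once as the mean of per-batch losses and once in a single full-dataset pass. It does not claim this is exactly how pytorch_tabnet logs its losses.

```python
import numpy as np

def recon_loss(y_pred, x):
    """Hypothetical variance-normalised MSE with every feature reconstructed."""
    var = np.std(x, axis=0, ddof=1) ** 2  # statistics of *these* rows only
    return (((y_pred - x) ** 2) / var).mean()

rng = np.random.default_rng(1)
# Heavy-tailed features make per-batch variances fluctuate strongly.
x = rng.standard_t(df=3, size=(4096, 5))
y_pred = x + rng.normal(scale=0.5, size=x.shape)  # one fixed set of predictions

batch_size = 64
batch_losses = [
    recon_loss(y_pred[i:i + batch_size], x[i:i + batch_size])
    for i in range(0, len(x), batch_size)
]
mean_of_batches = float(np.mean(batch_losses))
full_dataset = float(recon_loss(y_pred, x))
print(f"mean of per-batch losses: {mean_of_batches:.4f}")
print(f"full-dataset loss:        {full_dataset:.4f}")
```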
It seems to come from the forward method of the pretraining model: in eval mode (for inference) no obfuscation is applied, and a tensor of ones is used as the mask instead. The issue is that this seems to cause a deviation between the training loss and the validation loss. Perhaps we could use the obfuscation instead of the tensor of ones to calculate the validation loss? I could be wildly mistaken, and I apologise, but there is a strange deviation between what I expected the model to do in training and in validation, when actually it is working as it should (strangely).
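The effect of the two masking protocols can be illustrated in isolation. The NumPy sketch below is not the library's code: `toy_model` is a hypothetical stand-in that copies visible features and fills obfuscated ones with the training mean, the 20% obfuscation rate is an assumption, and the loss is normalised by training-set variance for simplicity. Eval-style scoring (nothing hidden, mask of ones) and train-style scoring (random obfuscation, only hidden entries counted) measure different tasks, so their values are not directly comparable.

```python
import numpy as np

rng = np.random.default_rng(2)
x_train = rng.normal(size=(2000, 4))
x_valid = rng.normal(size=(500, 4))
train_mean = x_train.mean(axis=0)
train_var = x_train.var(axis=0, ddof=1)

def toy_model(x_in, obf):
    """Hypothetical stand-in for the network: copies visible features and
    fills obfuscated ones (obf == 1) with the training mean."""
    return np.where(obf == 1, train_mean, x_in)

def masked_loss(y_pred, x, obf, eps=1e-9):
    """Variance-normalised squared error, averaged over the counted features."""
    per_row = ((y_pred - x) * obf) ** 2 @ (1.0 / train_var)
    return float((per_row / (obf.sum(axis=1) + eps)).mean())

# Eval-style: nothing is hidden from the model and every feature is counted.
eval_style = masked_loss(toy_model(x_valid, np.zeros_like(x_valid)), x_valid,
                         np.ones_like(x_valid))

# Train-style: hide a random ~20% of entries and score only those.
obf = (rng.random(x_valid.shape) < 0.2).astype(float)
train_style = masked_loss(toy_model(x_valid, obf), x_valid, obf)

print(eval_style, train_style)  # 0.0 versus a clearly positive value
```

Rows that happen to have no obfuscated entry contribute zero to the train-style average here, which mirrors a Bernoulli mask applied entry by entry.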
Could you indeed monitor both your training and validation set so that you can compare the same metric, not the loss vs the metric?
I think I may have found the issue. Looking back at the original paper:
If you look there, they recommend using a global standard deviation to normalize. That makes sense, although using a per-batch one also makes some sense, etc.
When I convert all the loss metrics, etc., to use this global standard deviation instead of the batch-based ones, we get what we expect :) losses well below 1.0. Of course, I take care to use only the training-fold standard deviations, as that is all the model knows at that point.
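A minimal sketch of that change, assuming NumPy arrays and hypothetical helpers `fit_feature_scale` and `global_norm_loss` (not functions from the library): the per-feature standard deviation is computed once on the training fold and reused, unchanged, for every mini-batch and for the validation set. Because the normaliser no longer depends on which rows are in a batch, the mean of equal-sized batch losses matches a single full-set pass.

```python
import numpy as np

def fit_feature_scale(x_train):
    """Per-feature standard deviation, computed once on the training fold."""
    std = x_train.std(axis=0, ddof=1)
    if np.any(std == 0):
        raise ValueError("drop constant features before fitting the scale")
    return std

def global_norm_loss(y_pred, x, mask, train_std, eps=1e-9):
    """Masked reconstruction loss normalised by fixed training-fold statistics."""
    per_row = (((y_pred - x) * mask / train_std) ** 2).sum(axis=1)
    return float((per_row / (mask.sum(axis=1) + eps)).mean())

rng = np.random.default_rng(3)
scales = [1.0, 5.0, 0.2]
x_train = rng.normal(scale=scales, size=(3000, 3))
x_valid = rng.normal(scale=scales, size=(800, 3))
train_std = fit_feature_scale(x_train)

mask = np.ones_like(x_valid)
pred = x_valid + rng.normal(scale=0.1 * train_std, size=x_valid.shape)

# The same train_std serves every mini-batch and the full validation set.
batch_losses = [
    global_norm_loss(pred[i:i + 200], x_valid[i:i + 200], mask[i:i + 200], train_std)
    for i in range(0, len(x_valid), 200)
]
full_loss = global_norm_loss(pred, x_valid, mask, train_std)
print(np.mean(batch_losses), full_loss)  # equal up to rounding, around 0.01
```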
I also removed the batch-based standard deviation fallback that is applied to zero values:
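```python
import torch

# embedded_x: (batch_size, n_features) tensor of embedded inputs for one batch
batch_means = torch.mean(embedded_x, dim=0)
batch_means[batch_means == 0] = 1  # zero means are replaced by 1
batch_stds = torch.std(embedded_x, dim=0) ** 2  # squared: per-feature variance
# zero-variance features fall back to their (non-zero) batch mean
batch_stds[batch_stds == 0] = batch_means[batch_stds == 0]
```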
I think there was another issue (https://github.com/dreamquark-ai/tabnet/issues/515) that referenced this problem too, where the above was causing the loss to explode, or at least to behave poorly. Note that I also removed any feature with zero variance :) that's just the bias feature :)
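Dropping constant columns up front means a training-fold standard deviation can never be zero, so the zero-variance fallback quoted above is no longer needed. A minimal NumPy sketch, with a hypothetical helper `varying_columns` and made-up data whose middle column plays the role of the bias feature:

```python
import numpy as np

def varying_columns(x_train):
    """Boolean mask of the columns that are not constant on the training fold."""
    return x_train.max(axis=0) > x_train.min(axis=0)

rng = np.random.default_rng(4)
n = 100
# Hypothetical data whose middle column is a constant bias feature.
x_train = np.column_stack([rng.normal(size=n), np.ones(n), rng.normal(size=n)])
x_valid = np.column_stack([rng.normal(size=n), np.ones(n), rng.normal(size=n)])

keep = varying_columns(x_train)  # decided on the training fold only
x_train, x_valid = x_train[:, keep], x_valid[:, keep]
print(keep.tolist(), x_train.shape, x_valid.shape)  # [True, False, True] (100, 2) (100, 2)
```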
So, in a nutshell, all I have done is replace the batch-based standard deviations above with global, training-fold-based standard deviations, as also recommended by the original paper.
Unless I am wildly mistaken, I think I've got it :) It would be great to hear your thoughts as well, in case I have missed something. I don't think I have, but that is science/engineering :)
If you agree, I might put in a PR for this :)
P.S. Sorry to nitpick, but should this be without the power exponent? batch_stds = torch.std(embedded_x, dim=0) ** 2, so that it is indeed the standard deviation and not the variance. Forgive me if I am mistaken... oh, I just re-read the formula, and I think I see where that is coming from :)
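For reference, the squaring follows if the loss squares the whole normalised error term. In notation assumed here ($f_{b,j}$: feature $j$ of row $b$; $\hat{f}_{b,j}$: its reconstruction; $S_{b,j} \in \{0, 1\}$: the obfuscation mask; $\sigma_j$: a per-feature standard deviation; the paper's exact normaliser may be defined slightly differently):

```latex
\left| \frac{(\hat{f}_{b,j} - f_{b,j})\, S_{b,j}}{\sigma_j} \right|^2
  = \frac{\big((\hat{f}_{b,j} - f_{b,j})\, S_{b,j}\big)^2}{\sigma_j^2}
```

So multiplying the squared, masked errors by the reciprocal of batch_stds (which already holds std squared, i.e. the variance) is the same as dividing each error by the standard deviation before squaring.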
Feel free to open a PR that I'll make sure to review carefully.
@hamaadshah any update on this ?
Hi @Optimox, apologies for the late response; I was just busy at work :) Bear with me: I will create a PR for this, run some experiments on some public datasets, and await your approval :)
P.S. @Optimox, hello, and sorry for the radio silence :) I've been busy with my contract work, which has now concluded, so I have a fair bit more time :) I am looking into raising a nice PR for this issue. I also have a fair few other improvements that I believe could be great for the project (those are perhaps other PRs). Step by step :) Bear with me.