
Validation loss for pretraining

Open hamaadshah opened this issue 8 months ago • 12 comments

Hi there: big fan of the project. I had one question that maybe some of you have come across. I have been able to successfully lower the pretraining loss to below 1, job done :) but the validation loss just hovers annoyingly above 1, with values like 1.01. I've tried increasing the capacity of the model, adding weight decay, etc., to no avail. It just gets stuck around 1.01.

Any thoughts or suggestions?
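
For context, a minimal sketch of the kind of setup being discussed, based on the pytorch-tabnet pretraining API; the hyperparameter values and the X_train / X_valid arrays are illustrative, not taken from the thread.

import torch
from pytorch_tabnet.pretraining import TabNetPretrainer

# X_train / X_valid are assumed to be numpy arrays of numerical
# (or label-encoded) features, with the validation fold held out.
pretrainer = TabNetPretrainer(
    n_d=8, n_a=8, n_steps=3,
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
)

# The unsupervised reconstruction loss is reported on the training batches
# and on each eval_set at the end of every epoch.
pretrainer.fit(
    X_train=X_train,
    eval_set=[X_valid],
    pretraining_ratio=0.8,  # fraction of features obfuscated per sample
    max_epochs=100,
    batch_size=1024,
)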

hamaadshah avatar Mar 18 '25 10:03 hamaadshah

Pretraining can be tricky, see my answer here.

The loss is also very dependent on the dataset distribution (the deviation between samples inside the dataset); unfortunately I don't have many tips to share. Please do share if you manage to lower your validation loss.

Optimox avatar Mar 20 '25 09:03 Optimox

Could this be due to how the validation and train losses are calculated? For the former you use the full validation data (via numpy on the CPU, rightly so given the GPU memory issues, I read that issue), whereas the train loss is computed on a per-batch basis. 1.01 isn't bad, I think; it's within the margin of reasonable error for an acceptable validation pretraining loss, I think?

hamaadshah avatar Mar 20 '25 09:03 hamaadshah

Above 1.0 is still a bit disappointing, even though it seems your pretraining is not doing worse than random, which is still something... The only way to check whether the pretraining is useful is to run cross-validation on your supervised dataset with and without pretraining and see whether you get an actual improvement.
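
As a hedged sketch of that check (the helper name and the CV loop are mine, not from the thread), one can warm-start the supervised model from the pretrainer via from_unsupervised and compare scores against training from scratch:

import numpy as np
from sklearn.model_selection import KFold
from sklearn.metrics import roc_auc_score
from pytorch_tabnet.tab_model import TabNetClassifier

def cv_score(X, y, pretrainer=None, n_splits=5):
    # Cross-validated AUC, optionally warm-started from a pretrained TabNet.
    # For a fully clean comparison, the pretrainer itself should also be fit
    # only on the training fold of each split.
    scores = []
    for train_idx, valid_idx in KFold(n_splits, shuffle=True, random_state=0).split(X):
        clf = TabNetClassifier()
        clf.fit(
            X[train_idx], y[train_idx],
            eval_set=[(X[valid_idx], y[valid_idx])],
            from_unsupervised=pretrainer,  # None -> train from scratch
        )
        scores.append(roc_auc_score(y[valid_idx], clf.predict_proba(X[valid_idx])[:, 1]))
    return float(np.mean(scores))

# Compare on the same folds:
# cv_score(X, y, pretrainer=None) vs cv_score(X, y, pretrainer=pretrainer)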

Optimox avatar Mar 20 '25 09:03 Optimox

Also, perhaps a naive question, but would you recommend focusing on n_a/n_d, on n_steps, or on increasing the number of independent/shared units? I appreciate the architecture, but I would love to know from folks with more experience in it what may work to narrow the gap between train and validation loss. Sorry to bother you.

hamaadshah avatar Mar 20 '25 17:03 hamaadshah

n_d, n_a and n_steps are definitely the most important parameters of the architecture. But the learning rate (and the batch size) also plays a big role in how the training goes; I would personally start by playing with the batch size and the obfuscation factor for the pretraining.
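
For concreteness, a hedged example of the kind of sweep this suggests (the values are arbitrary), assuming the obfuscation factor corresponds to the pretraining_ratio argument of fit:

from pytorch_tabnet.pretraining import TabNetPretrainer

# X_train / X_valid as above; each run reports train and validation
# reconstruction losses so the gap can be compared across settings.
for batch_size in (256, 1024, 4096):
    for pretraining_ratio in (0.2, 0.5, 0.8):
        pretrainer = TabNetPretrainer(n_d=16, n_a=16, n_steps=3)
        pretrainer.fit(
            X_train=X_train,
            eval_set=[X_valid],
            batch_size=batch_size,
            pretraining_ratio=pretraining_ratio,
            max_epochs=50,
            patience=10,
        )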

Optimox avatar Mar 22 '25 07:03 Optimox

Many thanks for all your advice :) One potentially silly question: I of course use the train set to get the train loss and all that good stuff, but I had some doubts about how the validation loss is calculated. To confirm my suspicions (or otherwise), I passed the train set to the validation metrics too, and strangely the two losses are different even though it is the same dataset. Is this because the loss for training is computed batch-wise, while the loss for validation/test is computed on the full dataset, i.e. without any batching? I'll take a closer look at the internals of your code too; surely I have missed something.
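
To make the suspicion concrete, here is a toy illustration (pure numpy, not the library's code) of how a loss averaged batch by batch with per-batch statistics can differ from the same loss computed once over the full dataset; the effect is exaggerated here by sorting the data so batches are homogeneous.

import numpy as np

rng = np.random.default_rng(0)
x = np.sort(rng.normal(size=(10_000, 1)), axis=0)  # ordered data -> homogeneous batches
pred = x + rng.normal(scale=0.5, size=x.shape)     # an imperfect "reconstruction"

def normalized_mse(p, t):
    # squared error divided by the variance of whatever targets are passed in
    return float(np.mean((p - t) ** 2 / np.var(t, axis=0)))

full_pass = normalized_mse(pred, x)  # one pass, global variance
per_batch = np.mean([normalized_mse(pred[i:i + 128], x[i:i + 128])
                     for i in range(0, len(x), 128)])  # per-batch variances, then averaged

# per_batch comes out far larger than full_pass, because each homogeneous batch
# has a much smaller variance than the full dataset.
print(full_pass, per_batch)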

hamaadshah avatar Mar 25 '25 20:03 hamaadshah

It seems to be the forward method in the pretraining piece: if you are in eval mode for inference then no obfuscation is applied and the mask is just a tensor of ones. The issue is that this seems to feed into a deviation between the training loss and the validation loss. Perhaps we could use the obfuscation groups instead of the tensor of ones to calculate the validation loss? I could be wildly mistaken, and I apologise, but there is a strange deviation between what I expected the model to do in training and validation, when actually it is working as intended (strangely).
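
A hedged sketch of that idea (this goes through the network directly and is not how the library computes its reported metric): keep the random obfuscation active while scoring the validation set, instead of the all-ones mask used by the eval branch. It assumes the UnsupervisedLoss helper in pytorch_tabnet.metrics and a fitted pretrainer.

import torch
from pytorch_tabnet.metrics import UnsupervisedLoss

network = pretrainer.network
network.train()   # keeps the random obfuscation active (eval() replaces it with ones);
                  # beware that train() also switches BatchNorm to batch statistics
with torch.no_grad():
    # batch this if the validation set is too large for GPU memory
    X = torch.from_numpy(X_valid).float().to(pretrainer.device)
    reconstruction, embedded_x, obf_vars = network(X)
    val_loss = UnsupervisedLoss(reconstruction, embedded_x, obf_vars)
network.eval()
print(float(val_loss))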

hamaadshah avatar Mar 25 '25 21:03 hamaadshah

Could you indeed monitor both your training and validation sets with the same metric, so that you compare like with like rather than the training loss against a validation metric?

Optimox avatar Mar 26 '25 08:03 Optimox

I think I may have found the issue - so looking back at the original paper:

[Image: screenshot of the self-supervised reconstruction loss formula from the original TabNet paper]
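
For reference, the formula in that screenshot reads roughly as follows (transcribed from memory from the TabNet paper, so treat the notation as approximate); the paper describes the denominator as the population standard deviation of the ground-truth data.

% \hat{f}_{b,j} is the reconstruction of feature j of sample b,
% f_{b,j} the ground truth, and S_{b,j} the binary obfuscation mask.
\[
\mathcal{L} \;=\; \sum_{b=1}^{B} \sum_{j=1}^{D}
\left|
  \frac{\bigl(\hat{f}_{b,j} - f_{b,j}\bigr)\, S_{b,j}}
       {\sqrt{\sum_{b=1}^{B} \bigl(f_{b,j} - \tfrac{1}{B}\sum_{b=1}^{B} f_{b,j}\bigr)^{2}}}
\right|^{2}
\]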

If you look there, they recommend normalizing by a global standard deviation. That makes sense; using a per-batch one also makes some sense, but the two are not equivalent.

Now, when I convert all the loss metrics to use this global standard deviation instead of the batch-based one, we get what we expect :) losses well below 1.0, well below. Of course, taking care to use only the training-fold standard deviations, since that is all the model knows at that point.

I also removed the batch-based imputation that is done when a feature's standard deviation is zero:

# per-feature means of the embedded inputs, over the current batch
batch_means = torch.mean(embedded_x, dim=0)
# features with a zero mean get 1 so they can serve as a fallback divisor below
batch_means[batch_means == 0] = 1

# note: std ** 2 is the per-feature *variance* of the batch, despite the name
batch_stds = torch.std(embedded_x, dim=0) ** 2
# zero-variance features are imputed with the batch mean to avoid dividing by zero
batch_stds[batch_stds == 0] = batch_means[batch_stds == 0]

I think there was another issue (https://github.com/dreamquark-ai/tabnet/issues/515) that referenced this problem too, where the above was causing the loss to explode, or at least not be healthy. Note that I also remove any feature that has constant variance :) that's just a bias feature :)

So, in a nutshell, all I have done is replace the batch-based standard deviations above with global, training-fold-based standard deviations, as recommended by the original paper.
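
As a hedged sketch, here is roughly what that change amounts to, written against the batch-std snippet above plus the rest of the library's UnsupervisedLoss as I understand it; the function name and the precomputed global_stds tensor are mine, not the library's.

import torch

def unsupervised_loss_global(y_pred, embedded_x, obf_vars, global_stds, eps=1e-9):
    # Same reconstruction loss as the library's, but normalized by per-feature
    # standard deviations computed once over the whole training fold
    # (global_stds, shape [n_features]) instead of per-batch statistics.
    errors = y_pred - embedded_x
    reconstruction_errors = torch.mul(errors, obf_vars) ** 2
    # squared, to mirror the variance-like denominator of the batch version
    features_loss = torch.matmul(reconstruction_errors, 1 / (global_stds ** 2 + eps))
    nb_reconstructed_variables = torch.sum(obf_vars, dim=1)
    features_loss = features_loss / (nb_reconstructed_variables + eps)
    return torch.mean(features_loss)

# Computed once, on the training fold only; if categorical embeddings are used,
# the stds must be taken on the embedded representation rather than the raw X.
# global_stds = torch.from_numpy(X_train.astype("float32")).std(dim=0)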

Unless I am wildly mistaken, I think I got it :) It would be great to hear your thoughts as well, in case I have missed something. I don't think I have, but that is science/engineering :)

If you agree, I might put in a PR for this :)

P.S. Sorry to nitpick, but should this be without the power exponent? batch_stds = torch.std(embedded_x, dim=0) ** 2, so that it is indeed a standard deviation and not a variance? Forgive me if I am mistaken... oh, I just re-read the formula, I think I see where that is coming from :)

hamaadshah avatar Mar 26 '25 18:03 hamaadshah

Feel free to open a PR that I'll make sure to review carefully.

Optimox avatar Mar 31 '25 14:03 Optimox

@hamaadshah any update on this ?

Optimox avatar Apr 22 '25 08:04 Optimox

Hi @Optimox: apologies for the late response, I was just busy at work :) Bear with me, I will create a PR for this, run some experiments on some public datasets, and await your approval :)

hamaadshah avatar Apr 23 '25 13:04 hamaadshah

P.S. @Optimox, hello and sorry for the radio silence :) I've been busy with my contract work, which has now concluded, so I have a fair bit more time :) I am looking into raising a nice PR for this issue. I also have a fair few other improvements that I believe could be great for the project (those are perhaps other PRs). Step by step :) bear with me.

hamaadshah avatar Oct 29 '25 15:10 hamaadshah