
TabDDPM generating NaNs for large datasets

Open · HLasse opened this issue

Question

When fitting TabDDPM for more than a single iteration, NaNs are generated, which leads to a ValueError during sampling here: https://github.com/vanderschaarlab/synthcity/blob/41e6e5acfd886dd4ebc0528039e9395a2a93b380/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py#L954-L955

Further Information

I'm fitting TabDDPM to a dataset of ~150k rows and ~1000 columns. All columns are numeric, contain no NaNs, and are scaled to z-scores. When training for more than a single iteration, I get the above error when calling model.generate(count=count, cond=cond). Any ideas about what might be happening? It seems to happen regardless of how the other parameters are set.
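Even when the input has no NaNs, z-scoring can hide problems: a constant or near-constant column has a standard deviation of (almost) zero, so dividing by it yields non-finite or extreme values. A minimal stdlib sketch for double-checking the preprocessing; `zscore_columns`, `all_finite`, and the `eps` threshold are hypothetical names chosen for illustration, not part of synthcity (with NumPy/pandas the same checks are vectorized, e.g. via `np.isfinite`):

```python
import math
from statistics import fmean, pstdev


def zscore_columns(columns, eps=1e-12):
    """Z-score each column and report columns whose std is ~0.

    Dividing by a ~0 std would produce inf/NaN, so such columns are
    flagged and replaced by zeros (centered only) in this sketch.
    """
    scaled, degenerate = [], []
    for idx, col in enumerate(columns):
        mu = fmean(col)
        sigma = pstdev(col, mu)  # population std, reusing the precomputed mean
        if sigma < eps:
            degenerate.append(idx)
            scaled.append([0.0 for _ in col])
        else:
            scaled.append([(x - mu) / sigma for x in col])
    return scaled, degenerate


def all_finite(columns):
    """True if every value is finite (no NaN, +inf, or -inf)."""
    return all(math.isfinite(x) for col in columns for x in col)


scaled, degenerate = zscore_columns([[1.0, 2.0, 3.0], [5.0, 5.0, 5.0]])
print(degenerate)          # [1]  (the constant second column)
print(all_finite(scaled))  # True
```

Flagging rather than silently dropping degenerate columns keeps the decision about how to handle them with the user.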

After subsampling to 15k rows, I was able to train for 1000 iterations, but only with at most 500 timesteps. Could this be an issue with larger datasets?

System Information

  • OS: Windows Server (😢)
  • Language Version: Python 3.10.9

HLasse, Apr 05 '24 11:04

Hi @HLasse,

I have not seen this issue before. Have you tried experimenting with the batch size to see if that gets around the issue?

One other thing to check would be numerical instability. There are a few divisions and logarithms in the code here; do any of their denominators or log arguments approach zero for your dataset?
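To make the suggestion concrete: a logarithm whose argument reaches 0 gives -inf (in PyTorch), and a later -inf - (-inf) or 0/0 becomes NaN. A common guard is to floor log arguments and denominators at a small epsilon, as in the sketch below. The helper names and `EPS` value are hypothetical, and whether clamping is appropriate depends on the specific expression in the TabDDPM code; in PyTorch the equivalent of `safe_log` would be `torch.clamp(x, min=eps).log()`.

```python
import math

# Hypothetical floor; float32's smallest normal number is ~1.18e-38,
# so this value is still representable there.
EPS = 1e-30


def safe_log(x, eps=EPS):
    """log(x) with x floored at eps: x -> 0 gives a large negative number, not -inf."""
    return math.log(max(x, eps))


def safe_div(num, den, eps=EPS):
    """num / den with |den| floored at eps, keeping the sign of den."""
    if abs(den) < eps:
        den = math.copysign(eps, den)
    return num / den


def near_zero(values, tol=1e-12):
    """Indices of values small enough to be risky as denominators or log arguments."""
    return [i for i, v in enumerate(values) if abs(v) < tol]


print(near_zero([1.0, 1e-15, -3e-13, 0.5]))  # [1, 2]
```

Logging `near_zero` hits on the intermediate tensors is one way to find which division or logarithm goes bad first for a given dataset.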

robsdavis, Apr 11 '24 14:04

I actually encountered the same problem on larger datasets. For me, it is not just that NaNs are sampled but the training loss also becomes NaN after a couple of iterations.
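When the training loss itself turns NaN, it helps to stop at the first non-finite value and record the iteration, rather than only discovering the problem at sampling time. A minimal sketch; `train_with_nan_guard` and the `step_fn` callback are hypothetical stand-ins for one optimizer step and are not synthcity APIs. In PyTorch, `torch.autograd.set_detect_anomaly(True)` can additionally point at the backward operation that first produced a NaN.

```python
import math


def train_with_nan_guard(step_fn, n_iter):
    """Call step_fn(i) -> loss for each iteration; stop at the first non-finite loss.

    Returns (finite losses seen so far, index of the first bad iteration or None).
    """
    losses = []
    for i in range(n_iter):
        loss = float(step_fn(i))
        if not math.isfinite(loss):
            return losses, i
        losses.append(loss)
    return losses, None


# Simulated loss curve that goes NaN at iteration 3.
fake = [0.9, 0.7, 0.6, float("nan"), 0.5]
print(train_with_nan_guard(lambda i: fake[i], len(fake)))  # ([0.9, 0.7, 0.6], 3)
```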

I agree with @robsdavis that this is related to numerical instability. In particular, in my case I traced the error down to

https://github.com/vanderschaarlab/synthcity/blob/943fa280687d236d783e53f40302838f5924f422/src/synthcity/plugins/core/models/tabular_ddpm/utils.py#L151-L154

I managed to stabilize the function and can open a pull request if you want. However, I cannot guarantee that this gives the same results on the datasets for which the non-adjusted variant works without issues. I have only tried this on a couple of datasets, and the results were close enough.
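The linked lines and the proposed fix are not reproduced in this thread, so the following is only an illustration of the general kind of stabilization involved. Multinomial diffusion code typically works in log space, and an expression like log(exp(a) - exp(b)) computed directly overflows for large `a` and hits log(0) when `a == b`. The sketch factors out `a` and uses the standard log1p/expm1 switch at -log 2 for accuracy; the function names and the `floor` value are hypothetical. Returning a finite floor instead of -inf is exactly the kind of change that can shift results slightly on datasets where the unmodified code already works. PyTorch provides `torch.log1p` and `torch.expm1` for a tensor version.

```python
import math


def log_sub_exp_naive(a, b):
    """log(exp(a) - exp(b)) computed directly: overflows for large a, log(0) if a == b."""
    return math.log(math.exp(a) - math.exp(b))


def log1mexp(d):
    """log(1 - exp(d)) for d < 0, choosing the more accurate formula for each range."""
    if d > -math.log(2):
        return math.log(-math.expm1(d))  # d near 0: expm1 avoids cancellation
    return math.log1p(-math.exp(d))      # d very negative: log1p stays accurate


def log_sub_exp_stable(a, b, floor=-1e4):
    """log(exp(a) - exp(b)) for a >= b, computed as a + log(1 - exp(b - a)).

    exp(b - a) lies in (0, 1], so nothing overflows. When a == b the exact
    value is -inf; this sketch returns the hypothetical finite `floor`
    instead, so later arithmetic such as -inf - (-inf) cannot produce NaN.
    """
    if b > a:
        raise ValueError("log_sub_exp_stable requires a >= b")
    d = b - a
    if d == 0.0:
        return floor
    return a + log1mexp(d)


print(log_sub_exp_stable(1000.0, 999.0))  # about 999.5413; the naive version raises OverflowError
```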

muellermarkus, May 27 '24 16:05

I am also using TabDDPM and get the same error:

ValueError: found NaNs in sample

limhasic, Oct 23 '24 07:10