synthcity
TabDDPM generating NaNs for large datasets
Question
When fitting TabDDPM for more than a single iteration, NaNs are generated, which leads to a ValueError
during sampling here: https://github.com/vanderschaarlab/synthcity/blob/41e6e5acfd886dd4ebc0528039e9395a2a93b380/src/synthcity/plugins/core/models/tabular_ddpm/gaussian_multinomial_diffsuion.py#L954-L955
Further Information
I'm fitting TabDDPM to a dataset of ~150k rows and ~1000 columns. All columns are numeric, contain no NaNs, and are scaled to z-scores. When training for more than a single iteration, I get the above error when calling `model.generate(count=count, cond=cond)`. Any ideas about what might be happening? It seems to happen no matter what the other parameters are set to.
After subsampling to 15k rows, I was able to train for 1000 iterations, but only up to 500 timesteps. Might be some issues with larger datasets?
System Information
- OS: Windows server (😢 )
- Language Version: Python 3.10.9
Hi @HLasse,
I have not seen this issue before. Have you tried experimenting with the batch size to see if that gets round the issue?
One other thing to check would be numerical instability. There are a few divisions and logarithms in the code here; do any of the denominators or log arguments approach zero for your dataset?
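One way to run the check suggested above is to dump the intermediate values that feed a division or a logarithm and count the risky ones. The following is only a minimal sketch in plain Python: `audit_values` and its `eps` threshold are hypothetical names chosen for illustration and are not part of synthcity.

```python
import math


def audit_values(values, eps=1e-12):
    """Count entries that would be risky as a denominator or a log argument.

    `audit_values` and `eps` are illustrative assumptions, not synthcity API.
    """
    report = {"nan": 0, "inf": 0, "near_zero": 0, "non_positive": 0}
    for v in values:
        if math.isnan(v):
            report["nan"] += 1
        elif math.isinf(v):
            report["inf"] += 1
        else:
            if abs(v) < eps:
                report["near_zero"] += 1  # unsafe as a denominator
            if v <= 0.0:
                report["non_positive"] += 1  # unsafe as a log argument
    return report


# Example: values as they might be collected from one training step
sample = [1.0, 0.0, 1e-20, -2.0, float("nan"), float("inf")]
print(audit_values(sample))
```

If the tensors live in PyTorch, they would need to be converted first (for example by flattening and calling `.tolist()`), which may be slow for large batches, so it is probably best applied to a small slice.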
I actually encountered the same problem on larger datasets. For me, it is not just that NaNs are sampled but the training loss also becomes NaN after a couple of iterations.
I agree with @robsdavis that this is related to numerical instability. In particular, I traced the error (in my case) down to
https://github.com/vanderschaarlab/synthcity/blob/943fa280687d236d783e53f40302838f5924f422/src/synthcity/plugins/core/models/tabular_ddpm/utils.py#L151-L154
I managed to stabilize the function and can create a related pull request if you want. However, I cannot guarantee that this gives the same results on the datasets for which the non-adjusted variant works without issues. I only tried this for a couple of datasets and the results were close enough.
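The linked helper is not reproduced in this thread, so the snippet below is not the fix described above. It is a generic sketch, in plain Python, of how a log-space computation can underflow and produce NaNs, and of the usual "subtract the maximum" remedy. The function names are hypothetical, and the naive version returns `-inf` for a zero argument to mimic what tensor libraries typically do instead of raising.

```python
import math

NEG_INF = float("-inf")


def naive_log_add_exp(a, b):
    """log(exp(a) + exp(b)) computed directly (illustrative, unstable)."""
    s = math.exp(a) + math.exp(b)  # underflows to 0.0 for very negative a, b
    # Mimic tensor-library behaviour, where log(0) yields -inf instead of raising
    return math.log(s) if s > 0.0 else NEG_INF


def stable_log_add_exp(a, b):
    """Same quantity, with the maximum factored out before exponentiating."""
    m = max(a, b)
    if m == NEG_INF:
        return NEG_INF  # both inputs are -inf, so the sum is exactly zero
    return m + math.log(math.exp(a - m) + math.exp(b - m))


a = b = -1000.0
naive = naive_log_add_exp(a, b)    # -inf, because exp(-1000) underflows
stable = stable_log_add_exp(a, b)  # -1000 + log(2)
print(naive, stable)
print(naive - naive)               # nan: -inf minus -inf is undefined
```

Once a single `-inf` enters a subtraction or is multiplied by zero, the result is NaN and propagates through the loss, which would be consistent with the loss turning NaN after a few iterations. Whether this particular pattern is the cause here would need to be confirmed against the linked lines.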
I also use TabDDPM and get the same error:
ValueError: found NaNs in sample