CTGAN
Fix bugs for conditional sampling
Hi everyone, this PR fixes issues #169 and #235, which report bugs in sampling from the conditional generator after training, i.e., the `sample` method of CTGAN. The details of the proposed changes are described and discussed in the issues, but here is a summary:
- Issue #169 concerns `_discrete_column_matrix_st` of the `DataSampler` in CTGAN. It affects the `sample_original_condvec` and `generate_cond_from_condition_column_info` methods. Adding `self._discrete_column_matrix_st[current_id] = st` fixes the issue for `sample_original_condvec`. To fix the issue for `generate_cond_from_condition_column_info`, I have replaced `_discrete_column_matrix_st` with `_discrete_column_cond_st`. The difference between the two fixes comes from creating a conditional vector vs. selecting a conditional vector from the data (which also contains continuous variables and thus requires different indices).
- Issue #235 was only partially fixed by changing `_discrete_column_matrix_st` to `_discrete_column_cond_st`. There were still problems because the generator contains batch-norm layers and the model was still in `train` mode. Setting `self._generator.eval()` fixed the issue. For performance, I also wrapped sampling in `with torch.no_grad()`.
- I have written `test_synthesizer_sampling` to test the sampling methods. I noticed that `test_log_frequency` was failing, but after looking into it in more detail, the test appears to be outdated (#20). At inference time, the generator's sampling always uses the empirical frequency (I am not sure whether this is intentional; perhaps a feature request to sample with log frequency would be appropriate). During training, the default option is the log frequency, but that is not what the test was assessing. I have therefore updated this test, but it could also be removed.
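To illustrate why the `eval()` fix in the second bullet matters, here is a minimal toy sketch (not CTGAN's actual generator, just a hypothetical `nn.Sequential` stand-in): a batch-norm layer normalizes with per-batch statistics in `train` mode but with running statistics in `eval` mode, so sampling with the model left in `train` mode produces different, batch-dependent outputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a generator that contains a batch-norm layer.
gen = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))

x = torch.randn(4, 8)

gen.train()                 # batch-norm uses the current batch's statistics
out_train = gen(x)

gen.eval()                  # batch-norm switches to its running statistics
with torch.no_grad():       # skip autograd bookkeeping at inference time
    out_eval = gen(x)

# Same input, different outputs: the two modes normalize with
# different statistics, which is the bug fixed by self._generator.eval().
print(torch.allclose(out_train, out_eval))
```

The `torch.no_grad()` context is purely a performance improvement here: it avoids building the autograd graph during sampling, where gradients are never needed.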