
Fix bugs for conditional sampling

AndresAlgaba opened this pull request • 0 comments

Hi everyone, this PR fixes issues #169 and #235, which report bugs in sampling from the conditional generator after training, i.e., in the `sample` method of CTGAN. The proposed changes are described and discussed in detail in the issues; here is a summary:

  • Issue #169 concerns the `_discrete_column_matrix_st` of the `DataSampler` in CTGAN. It affects the `sample_original_condvec` and `generate_cond_from_condition_column_info` methods. Adding `self._discrete_column_matrix_st[current_id] = st` fixes the issue for `sample_original_condvec`. To fix the issue for `generate_cond_from_condition_column_info`, I have replaced `_discrete_column_matrix_st` with `_discrete_column_cond_st`. The two fixes differ because one creates a conditional vector, while the other selects a conditional vector from the transformed data, which also contains continuous variables and therefore requires different indices.
  • Issue #235 was only partially fixed by replacing `_discrete_column_matrix_st` with `_discrete_column_cond_st`. Issues remained because the generator contains batch-norm layers and the model was still in train mode; calling `self._generator.eval()` before sampling fixed this. For performance, I also wrapped the sampling in `with torch.no_grad()`.
  • I have written `test_synthesizer_sampling` to test the sampling methods. I noticed that `test_log_frequency` was failing, but on closer inspection the test appears outdated (#20): at inference time the sampler always uses the empirical frequencies (I am not sure whether this is intentional; perhaps a feature request to also support log-frequency sampling at inference would be appropriate). During training, the default is the log frequency, but that is not what the test was assessing. I have therefore changed this test, but it could also be removed.
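To make the index mismatch in the first bullet concrete, here is a toy sketch (not the actual CTGAN code) of why the start indices into the conditional vector differ from the start indices into the transformed data matrix: the conditional vector contains only the discrete one-hot blocks, while the data matrix interleaves them with continuous-column outputs. The column layout below is hypothetical.

```python
# Hypothetical transformed layout: a continuous column occupies several
# output dimensions (e.g. 1 scalar + mode one-hots), a discrete column
# occupies one dimension per category.
columns = [
    ("continuous", 11),   # 1 scalar + 10 mode indicators
    ("discrete", 3),      # 3 categories
    ("discrete", 4),      # 4 categories
]

cond_st = []    # start index within the conditional vector (discrete only)
matrix_st = []  # start index within the full data matrix (all columns)
cond_pos = 0
matrix_pos = 0
for kind, width in columns:
    if kind == "discrete":
        cond_st.append(cond_pos)
        matrix_st.append(matrix_pos)
        cond_pos += width
    matrix_pos += width

print(cond_st)    # [0, 3]
print(matrix_st)  # [11, 14]
```

Using the matrix start indices to slice a conditional vector (or vice versa) would read the wrong dimensions whenever continuous columns precede the discrete one, which is why the two fixes use different index tables.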
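The batch-norm point in the second bullet can be sketched with a minimal standalone example (not CTGAN code): in train mode, batch norm normalizes by the statistics of the current batch and updates its running statistics, so sampling output depends on the batch and drifts between calls; in eval mode it uses the frozen running statistics and is reproducible.

```python
import torch

torch.manual_seed(0)
bn = torch.nn.BatchNorm1d(2)  # stand-in for a generator containing batch norm

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

bn.train()                    # train mode: normalize by batch statistics
out_train = bn(x)             # also mutates the running statistics

bn.eval()                     # eval mode: normalize by running statistics
with torch.no_grad():         # skip autograd bookkeeping at inference
    out_eval = bn(x)
    out_eval_again = bn(x)

# Train-mode and eval-mode outputs differ for the same input, while
# eval-mode outputs are reproducible across calls.
print(torch.allclose(out_train, out_eval))       # False
print(torch.allclose(out_eval, out_eval_again))  # True
```

This is the rationale for calling `eval()` on the generator before sampling; `torch.no_grad()` is orthogonal and only avoids building the autograd graph.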

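For context on the third bullet, here is a hypothetical numerical illustration of log-frequency versus empirical category sampling. The `log(1 + count)` weighting used below is an assumption for illustration; the point is only that log weighting flattens the distribution so rare categories are sampled far above their empirical share during training.

```python
import math

# Hypothetical category counts for one discrete column.
counts = {"A": 900, "B": 90, "C": 10}
total = sum(counts.values())

# Empirical sampling probabilities.
empirical = {k: v / total for k, v in counts.items()}

# Log-frequency weights (illustrative log(1 + count) form), normalized.
log_w = {k: math.log(1 + v) for k, v in counts.items()}
z = sum(log_w.values())
log_freq = {k: w / z for k, w in log_w.items()}

print(empirical["C"])            # 0.01
print(round(log_freq["C"], 3))   # 0.175
```

Under empirical sampling, the rare category C conditions only 1% of training batches; under log-frequency weighting it is drawn far more often, which is what the training-time default is meant to achieve.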
AndresAlgaba, Jul 20 '22 12:07