Conditional sampling and cross-entropy loss

Open AndresAlgaba opened this issue 3 years ago • 4 comments

Hi everyone! I have a question/problem regarding the conditional sampling in the sample method of the CTGANSynthesizer using the condition_column and condition_value arguments. For example, derived from the Usage Example in the README:

samples = ctgan.sample(1000, condition_column='sex', condition_value=' Male')

Note that the whitespace in condition_value=' Male' is intentional, see #233 and #234.
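Because condition_value has to match the stored category label exactly (including the leading space in the demo data), it can help to look up the exact label before sampling. A minimal pandas sketch; find_category is a hypothetical helper, not part of CTGAN:

```python
import pandas as pd


def find_category(data, column, value):
    """Return the stored label of `column` that matches `value`.

    Hypothetical helper: an exact match wins; otherwise labels are compared
    after stripping surrounding whitespace. Raises ValueError if nothing matches.
    """
    labels = data[column].dropna().unique()
    if value in labels:
        return value
    matches = [label for label in labels if str(label).strip() == str(value).strip()]
    if not matches:
        raise ValueError(f'{value!r} is not a category of {column!r}: {sorted(map(str, labels))}')
    return matches[0]


toy = pd.DataFrame({'sex': [' Male', ' Female', ' Male']})
print(repr(find_category(toy, 'sex', 'Male')))  # ' Male'
```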

Environment Details

  • CTGAN version: latest (0.5.2.dev1)
  • Python version: 3.9.7
  • Operating System: Windows

Problem description

Intuitively, it seems that when a model is sufficiently trained, the conditional sampling should (almost) only generate examples satisfying the criteria given by the conditional vector. To monitor whether this is happening during training, I've printed the cross-entropy loss as follows:

if self._verbose:
    print(f'Epoch {i+1}, Loss G: {loss_g.detach().cpu(): .4f}, '  # noqa: T001
          f'Loss D: {loss_d.detach().cpu(): .4f}, '
          f'Cross Entropy: {cross_entropy.detach().cpu(): .4f}',
          flush=True)

https://github.com/sdv-dev/CTGAN/blob/5358af7cd653eb0c3a96f9671c90fbdde9672f45/ctgan/synthesizers/ctgan.py#L419

The cross-entropy loss rapidly approaches zero, indicating that the generated examples satisfy the conditional vector criteria during training.
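The printed quantity is, roughly, a masked cross-entropy: for each row, only the discrete column that row was conditioned on contributes, and the target is the category set to 1 in the condition vector. A NumPy sketch of that computation, modeled loosely on CTGAN's _cond_loss; the function name masked_cond_cross_entropy and the (logit_start, cond_start, n_categories) span layout are assumptions for illustration:

```python
import numpy as np


def masked_cond_cross_entropy(logits, cond, mask, spans):
    """Cross-entropy between generator logits and the conditioned category.

    logits: (batch, total_dim) raw generator output
    cond:   (batch, n_categories) one-hot condition vectors, concatenated per column
    mask:   (batch, n_discrete_columns), 1 for the column each row is conditioned on
    spans:  list of (logit_start, cond_start, n_categories), one per discrete column
    """
    batch = logits.shape[0]
    per_column = []
    for logit_st, cond_st, n_cat in spans:
        z = logits[:, logit_st:logit_st + n_cat]
        # Numerically stable log-softmax over this column's categories
        z = z - z.max(axis=1, keepdims=True)
        log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
        target = cond[:, cond_st:cond_st + n_cat].argmax(axis=1)
        per_column.append(-log_probs[np.arange(batch), target])
    loss = np.stack(per_column, axis=1)  # (batch, n_discrete_columns)
    # Only the conditioned column of each row counts; average over the batch
    return float((loss * mask).sum() / batch)


# Two discrete columns with 2 and 3 categories; the logits contain only those spans
spans = [(0, 0, 2), (2, 2, 3)]
cond = np.array([[1, 0, 0, 0, 0],
                 [0, 0, 0, 0, 1]], dtype=float)
mask = np.array([[1, 0], [0, 1]], dtype=float)
confident = np.array([[10., -10., 0., 0., 0.],
                      [0., 0., -10., -10., 10.]])
uniform = np.zeros((2, 5))
print(masked_cond_cross_entropy(confident, cond, mask, spans))  # close to 0
print(masked_cond_cross_entropy(uniform, cond, mask, spans))    # about 0.896 (= ln(6) / 2)
```

A loss near zero therefore only says that, during training, the generator's logits for the conditioned column strongly favor the requested category.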

However, when sampling with the sample method, the generated rows satisfy the condition barely more often than when no condition is given (in which case the condition vectors follow the empirical distribution of the training data). I could not find a problem in the code and was wondering whether my intuition is wrong.

What I already tried

from ctgan import CTGANSynthesizer
from ctgan import load_demo

data = load_demo()

discrete_columns = [
    'workclass',
    'education',
    'marital-status',
    'occupation',
    'relationship',
    'race',
    'sex',
    'native-country',
    'income'
]

ctgan = CTGANSynthesizer(epochs=100, verbose=True)
ctgan.fit(data, discrete_columns)

# conditional
samples = ctgan.sample(1000, condition_column='sex', condition_value=' Male')
samples["sex"].value_counts().plot(kind='bar')

# unconditional
samples = ctgan.sample(1000)
samples["sex"].value_counts().plot(kind='bar')
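Instead of comparing two bar plots by eye, the effect of conditioning can be summarized as one number: the share of generated rows that carry the requested value. A small pandas sketch; condition_match_rate is a hypothetical helper, and the toy frames merely stand in for the output of ctgan.sample:

```python
import pandas as pd


def condition_match_rate(samples, column, value):
    """Fraction of rows in `samples` whose `column` equals `value` exactly."""
    if len(samples) == 0:
        return float('nan')
    return float((samples[column] == value).mean())


# Toy stand-ins for conditional and unconditional samples
conditional = pd.DataFrame({'sex': [' Male'] * 9 + [' Female']})
unconditional = pd.DataFrame({'sex': [' Male'] * 6 + [' Female'] * 4})
print(condition_match_rate(conditional, 'sex', ' Male'))    # 0.9
print(condition_match_rate(unconditional, 'sex', ' Male'))  # 0.6
```

If conditioning works as expected, the conditional rate should approach 1, while the unconditional rate should stay near the category's frequency in the training data.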

I have also run a similar analysis based on the integration test https://github.com/sdv-dev/CTGAN/blob/5358af7cd653eb0c3a96f9671c90fbdde9672f45/tests/integration/synthesizer/test_ctgan.py#L123 and got similar results.

AndresAlgaba, Jul 19 '22 11:07

Upon further inspection, I believe there may be a problem in:

if condition_column is not None and condition_value is not None:
    condition_info = self._transformer.convert_column_name_value_to_id(
        condition_column, condition_value)
    global_condition_vec = self._data_sampler.generate_cond_from_condition_column_info(
        condition_info, self._batch_size)

https://github.com/sdv-dev/CTGAN/blob/5358af7cd653eb0c3a96f9671c90fbdde9672f45/ctgan/synthesizers/ctgan.py#L443

The requested category always ends up in the first columns of global_condition_vec, regardless of condition_column. For example, condition_column='sex', condition_value=' Male' and condition_column='workclass', condition_value=' State-gov' produce the same global_condition_vec, because each is the first category of its respective discrete column. I will look into this further.

AndresAlgaba, Jul 19 '22 14:07

Update: I believe the problem may lie in the generate_cond_from_condition_column_info method of the DataSampler.

def generate_cond_from_condition_column_info(self, condition_info, batch):
    """Generate the condition vector."""
    vec = np.zeros((batch, self._n_categories), dtype='float32')
    id_ = self._discrete_column_matrix_st[condition_info['discrete_column_id']]
    id_ += condition_info['value_id']
    vec[:, id_] = 1
    return vec

https://github.com/sdv-dev/CTGAN/blob/5358af7cd653eb0c3a96f9671c90fbdde9672f45/ctgan/data_sampler.py#L153

Specifically, the _discrete_column_matrix_st attribute is initialized as:

self._discrete_column_matrix_st = np.zeros(n_discrete_columns, dtype='int32')

It does not seem to be updated afterward.

Therefore:

id_ = self._discrete_column_matrix_st[condition_info['discrete_column_id']]

always evaluates to 0, so id_ ends up equal to condition_info['value_id'] no matter which discrete column was requested.

I believe _discrete_column_cond_st has to be used instead of _discrete_column_matrix_st. This seems to generate the appropriate global_condition_vec.
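The collision caused by the zero offsets, and the effect of switching to per-column start offsets, can be reproduced outside CTGAN. A NumPy sketch under assumed category counts; generate_cond is a simplified, hypothetical stand-in for generate_cond_from_condition_column_info, not the library's code:

```python
import numpy as np


def generate_cond(discrete_column_id, value_id, batch, n_categories, column_start):
    """Build a batch of identical one-hot condition vectors."""
    vec = np.zeros((batch, n_categories), dtype='float32')
    vec[:, column_start[discrete_column_id] + value_id] = 1
    return vec


# Hypothetical category counts of three discrete columns, in column order
n_categories_per_column = np.array([9, 16, 2], dtype='int32')
n_categories = int(n_categories_per_column.sum())

# Start of each column inside the condition vector (the role of _discrete_column_cond_st)
cond_st = np.concatenate([[0], np.cumsum(n_categories_per_column)[:-1]]).astype('int32')
# Offsets that are never filled in (the reported state of _discrete_column_matrix_st)
matrix_st = np.zeros(len(n_categories_per_column), dtype='int32')

# First category of column 0 vs. first category of column 2
collide_before = np.array_equal(generate_cond(0, 0, 4, n_categories, matrix_st),
                                generate_cond(2, 0, 4, n_categories, matrix_st))
fixed_b = generate_cond(2, 0, 4, n_categories, cond_st)
collide_after = np.array_equal(generate_cond(0, 0, 4, n_categories, cond_st), fixed_b)
print(collide_before, collide_after)  # True False
print(fixed_b[0].argmax())            # 25
```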

However, this alone does not seem to solve the original problem. I will look further into the conditional generation step, which was my main concern.


I noticed that matrix_st = self._discrete_column_matrix_st[col_idx] is also used here: https://github.com/sdv-dev/CTGAN/blob/5358af7cd653eb0c3a96f9671c90fbdde9672f45/ctgan/data_sampler.py#L123. Again, I believe matrix_st will always be zero there. I am not sure whether this causes any unwanted behavior.
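Assuming matrix_st is meant to point into a row of the transformed training data (as its name suggests), its offsets would generally differ from the condition-vector offsets, because that row also holds the encodings of the continuous columns. A plain-Python sketch with an assumed layout (the dims are illustrative only); column_offsets is hypothetical:

```python
def column_offsets(layout):
    """Return (matrix_st, cond_st) for the discrete columns of a transformed row.

    layout: list of (kind, dim) in column order, kind being 'continuous' or 'discrete'.
    matrix_st: where each discrete one-hot block starts in the full transformed row.
    cond_st:   where the same block starts in the condition vector, which only
               contains discrete columns.
    """
    matrix_st, cond_st = [], []
    st = 0    # position in the full transformed row
    cond = 0  # position in the condition vector
    for kind, dim in layout:
        if kind == 'discrete':
            matrix_st.append(st)
            cond_st.append(cond)
            cond += dim
        st += dim
    return matrix_st, cond_st


# Assumed layout: continuous, discrete (2 categories), continuous, discrete (9 categories)
layout = [('continuous', 11), ('discrete', 2), ('continuous', 11), ('discrete', 9)]
print(column_offsets(layout))  # ([11, 24], [0, 2])
```

Under this assumption, an all-zero matrix_st would slice the first columns of the row rather than the intended discrete column.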

AndresAlgaba, Jul 19 '22 14:07

Hi @AndresAlgaba, thanks for filing this and looking into it. I just wanted to confirm that we've seen it.

We can update this issue when we have more bandwidth to debug. If you do end up finding the root cause, please let us know!

BTW, what is your overall use case for conditional sampling / synthetic data? Even if this conditional vector manipulation is not working as intended, you can still use a reject-sampling approach (synthesizing data without any conditions and then throwing away the rows you don't need). The SDV library provides convenience wrappers around CTGAN to help you do exactly this. This User Guide may be helpful, particularly the conditional sampling section.
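A minimal sketch of the reject-sampling workaround, assuming sample_fn is any function that returns n unconditional synthetic rows as a pandas DataFrame (for example ctgan.sample). reject_sample and its parameters are hypothetical and are not the SDV wrapper's API:

```python
import numpy as np
import pandas as pd


def reject_sample(sample_fn, n_rows, column, value, batch_size=1000, max_tries=100):
    """Draw unconditional batches and keep only the rows where column == value."""
    kept = []
    n_kept = 0
    for _ in range(max_tries):
        batch = sample_fn(batch_size)
        match = batch[batch[column] == value]
        kept.append(match)
        n_kept += len(match)
        if n_kept >= n_rows:
            break
    if n_kept < n_rows:
        raise RuntimeError(f'only {n_kept} of {n_rows} rows matched after {max_tries} tries')
    return pd.concat(kept, ignore_index=True).head(n_rows)


# Toy unconditional sampler standing in for a fitted model
rng = np.random.default_rng(0)


def toy_sample(n):
    return pd.DataFrame({'sex': rng.choice([' Male', ' Female'], size=n, p=[0.67, 0.33])})


rows = reject_sample(toy_sample, 500, 'sex', ' Male')
print(len(rows))  # 500
```

The max_tries bound keeps the loop finite for rare categories, where most sampled rows are thrown away.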

npatki, Jul 19 '22 21:07

Hi @npatki, no problem, and thanks for the confirmation!

Besides the change from _discrete_column_matrix_st to _discrete_column_cond_st mentioned above (https://github.com/sdv-dev/CTGAN/blob/5358af7cd653eb0c3a96f9671c90fbdde9672f45/ctgan/data_sampler.py#L153; by the way, issue #169 describes a similar problem with _discrete_column_matrix_st), I have found that proper conditional sampling also requires putting the generator in evaluation mode:

self._generator.eval()

This is needed because the generator uses batch normalization, which behaves differently in training and evaluation mode.
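Why evaluation mode matters here: during conditional sampling every row of a batch carries the same condition vector, so the condition's contribution to a layer's pre-normalization activations is a per-feature constant. In training mode, batch normalization subtracts the current batch mean, which removes exactly such a constant from that layer's normalized activations; in evaluation mode the running statistics from training are used, so the condition survives. A NumPy sketch of the effect, using a simplified batch norm without the learned scale and shift and random stand-in activations rather than CTGAN's generator:

```python
import numpy as np


def batch_norm(x, mean, var, eps=1e-5):
    """Normalize each feature of x with the given per-feature mean and variance."""
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
batch, dim = 64, 8
noise_part = rng.normal(size=(batch, dim))           # varies from row to row
cond_a = np.tile(rng.normal(size=dim), (batch, 1))   # one condition, identical in every row
cond_b = np.tile(rng.normal(size=dim), (batch, 1))   # another condition, also identical


def train_mode(h):
    # Training mode: statistics of the current batch
    return batch_norm(h, h.mean(axis=0), h.var(axis=0))


# Evaluation mode: fixed running statistics (assumed here to be zero mean, unit variance)
running_mean, running_var = np.zeros(dim), np.ones(dim)


def eval_mode(h):
    return batch_norm(h, running_mean, running_var)


same_in_train = np.allclose(train_mode(noise_part + cond_a), train_mode(noise_part + cond_b))
same_in_eval = np.allclose(eval_mode(noise_part + cond_a), eval_mode(noise_part + cond_b))
print(same_in_train, same_in_eval)  # True False
```

During training the condition vectors differ from row to row, so this cancellation does not occur there, which would be consistent with the cross-entropy loss dropping during training while sampling in training mode ignores the condition.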

I have opened PR #236 with the proposed changes.

Thank you for the suggestion about the SDV library! An issue there (sdv-dev/SDV#623) is what led me to examine the conditional sampling :).

AndresAlgaba, Jul 20 '22 12:07