CTGAN

discrete_column_matrix_st from data_sampler class is always 0

Open tejuafonja opened this issue 4 years ago • 4 comments

Hello,

I noticed that discrete_column_matrix_st in the data_sampler class is never updated, so every entry stays 0. If that is the case, the function sample_original_condvec(self, batch) would not pick the correct category of the discrete column. I'm not sure whether this is the intended behavior. My understanding is that discrete_column_matrix_st is meant to record where each discrete column starts in the data matrix. If that assumption is correct, setting self._discrete_column_matrix_st[current_id] = st, as done below, should solve the issue. Please correct me if my assumption is wrong, and thank you for open-sourcing your code.

# excerpt: data_sampler.py, line 60-78
for column_info in output_info:
    if is_discrete_column(column_info):
        span_info = column_info[0]
        ed = st + span_info.dim
        category_freq = np.sum(data[:, st:ed], axis=0)
        if log_frequency:
            category_freq = np.log(category_freq + 1)
        category_prob = category_freq / np.sum(category_freq)
        self._discrete_column_category_prob[current_id, :span_info.dim] = (
            category_prob)
        self._discrete_column_cond_st[current_id] = current_cond_st
        self._discrete_column_n_category[current_id] = span_info.dim

        # **updated self._discrete_column_matrix_st**
        self._discrete_column_matrix_st[current_id] = st

        current_cond_st += span_info.dim
        current_id += 1
        st = ed
    else:
        st += sum([span_info.dim for span_info in column_info])
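
To make the effect concrete, here is a minimal, library-free sketch of the offset bookkeeping described above. It is not the project's code: SpanInfo, is_discrete_column, and the toy layout are simplified stand-ins assumed for illustration, and the lookup assumes the start offset is used to slice a row of the full transformed data matrix.

```python
from collections import namedtuple

# Stand-in for the library's span descriptor (assumed fields, illustration only).
SpanInfo = namedtuple("SpanInfo", ["dim", "activation_fn"])


def is_discrete_column(column_info):
    # Assumption: a discrete column is represented by a single softmax span.
    return len(column_info) == 1 and column_info[0].activation_fn == "softmax"


def discrete_matrix_starts(output_info):
    """Return the start offset of each discrete column in the full data matrix."""
    starts = []
    st = 0
    for column_info in output_info:
        if is_discrete_column(column_info):
            starts.append(st)
            st += column_info[0].dim
        else:
            # Continuous columns still occupy width in the matrix.
            st += sum(span_info.dim for span_info in column_info)
    return starts


def picked_category(row, start, n_category):
    """Index of the largest entry in row[start:start + n_category]."""
    window = row[start:start + n_category]
    return max(range(len(window)), key=window.__getitem__)


# Toy layout: discrete (3 categories), continuous (1 + 2 dims), discrete (2 categories).
output_info = [
    [SpanInfo(3, "softmax")],
    [SpanInfo(1, "tanh"), SpanInfo(2, "softmax")],
    [SpanInfo(2, "softmax")],
]
starts = discrete_matrix_starts(output_info)

# One transformed row: category 2 of the first column, category 1 of the last.
row = [0, 0, 1, 0.5, 1, 0, 0, 1]

stale = picked_category(row, 0, 2)             # offset left at 0
updated = picked_category(row, starts[1], 2)   # offset updated to st
print(starts, stale, updated)
```

In this toy row the second discrete column is read from the first column's slots when its offset stays at 0, so the picked category differs from the one actually set.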

tejuafonja avatar Aug 27 '21 10:08 tejuafonja

I agree that this needs attention. Besides sample_original_condvec, generate_cond_from_condition_column_info is also affected. In the current state, the condition vectors created are not correct. Thanks @tejuafonja for suggesting a fix!

JonathanDZiegler avatar Apr 28 '22 07:04 JonathanDZiegler

The fix seems to have an issue with mixed tables. Your st counter also advances over continuous columns, so the stored offset includes them, which is not correct. A simple expansion to two counters (one over all columns and one over discrete columns only) seems to do the trick:

st = 0
discrete_st = 0
current_id = 0
current_cond_st = 0
for column_info in output_info:
    if is_discrete_column(column_info):
        span_info = column_info[0]
        ed = st + span_info.dim
        discrete_ed = discrete_st + span_info.dim
        category_freq = np.sum(data[:, st:ed], axis=0)
        if log_frequency:
            category_freq = np.log(category_freq + 1)
        category_prob = category_freq / np.sum(category_freq)
        self._discrete_column_category_prob[current_id, :span_info.dim] = category_prob
        self._discrete_column_cond_st[current_id] = current_cond_st
        self._discrete_column_n_category[current_id] = span_info.dim

        # **updated self._discrete_column_matrix_st**
        self._discrete_column_matrix_st[current_id] = discrete_st

        current_cond_st += span_info.dim
        current_id += 1
        st = ed
        discrete_st = discrete_ed
    else:
        st += sum([span_info.dim for span_info in column_info])
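
The two proposals differ only in which running offset gets stored. The sketch below, again using simplified stand-ins (SpanInfo, is_discrete_column, and the toy layout are assumptions, not the library's definitions), prints both offsets next to the condition-vector offset for a small mixed table. Which one is appropriate depends on what the consumer slices: a full data matrix that includes continuous columns, or a discrete-only layout.

```python
from collections import namedtuple

# Stand-in for the library's span descriptor (assumed fields, illustration only).
SpanInfo = namedtuple("SpanInfo", ["dim", "activation_fn"])


def is_discrete_column(column_info):
    # Assumption: a discrete column is represented by a single softmax span.
    return len(column_info) == 1 and column_info[0].activation_fn == "softmax"


def offsets(output_info):
    """Collect, for each discrete column, the three running offsets in the loop."""
    st = discrete_st = cond_st = 0
    rows = []
    for column_info in output_info:
        if is_discrete_column(column_info):
            dim = column_info[0].dim
            rows.append({
                "full_matrix_st": st,             # counts every column
                "discrete_only_st": discrete_st,  # counts discrete columns only
                "cond_st": cond_st,               # offset in the condition vector
            })
            st += dim
            discrete_st += dim
            cond_st += dim
        else:
            st += sum(span_info.dim for span_info in column_info)
    return rows


# Toy mixed table: discrete (3), continuous (1 + 2), discrete (2).
mixed = [
    [SpanInfo(3, "softmax")],
    [SpanInfo(1, "tanh"), SpanInfo(2, "softmax")],
    [SpanInfo(2, "softmax")],
]
result = offsets(mixed)
for row in result:
    print(row)
```

Because discrete_st and current_cond_st start at 0 and grow by the same amount in the same branch, the discrete-only offset always equals the condition-vector offset, so the two-counter variant stores the same values in both arrays.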

JonathanDZiegler avatar Apr 28 '22 08:04 JonathanDZiegler

Thanks for improving the suggested fix!

tejuafonja avatar May 01 '22 10:05 tejuafonja

Hi everyone, I believe that the initial solution from @tejuafonja is correct. See #236 for more details on this, and feel free to give some feedback there!

AndresAlgaba avatar Jul 25 '22 08:07 AndresAlgaba