CTGAN
discrete_column_matrix_st from data_sampler class is always 0
Hello,
I noticed that discrete_column_matrix_st in the data_sampler class is never updated, so it is always 0. If so, the function sample_original_condvec(self, batch) would not pick the correct category of the discrete column. I'm not sure whether this is the intended behavior. My understanding is that discrete_column_matrix_st is meant to keep track of each discrete column's start position in the data matrix. If that assumption is correct, updating it with self._discrete_column_matrix_st[current_id] = st, as done below, should solve the issue. Please correct me if I'm wrong, and thank you for open-sourcing your code.
# excerpt: data_sampler.py, line 60-78
for column_info in output_info:
    if is_discrete_column(column_info):
        span_info = column_info[0]
        ed = st + span_info.dim
        category_freq = np.sum(data[:, st:ed], axis=0)
        if log_frequency:
            category_freq = np.log(category_freq + 1)
        category_prob = category_freq / np.sum(category_freq)
        self._discrete_column_category_prob[current_id, :span_info.dim] = (
            category_prob)
        self._discrete_column_cond_st[current_id] = current_cond_st
        self._discrete_column_n_category[current_id] = span_info.dim
        # **updated self._discrete_column_matrix_st**
        self._discrete_column_matrix_st[current_id] = st
        current_cond_st += span_info.dim
        current_id += 1
        st = ed
    else:
        st += sum([span_info.dim for span_info in column_info])
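To make the reported symptom concrete, here is a minimal sketch of what a lookup by start offset does when the offsets are stuck at 0. It is not the library's code: the row layout, the helper pick_category, and the lists stale_st and tracked_st are hypothetical and exist only for illustration.

```python
import numpy as np

# Hypothetical transformed row (an assumption, not CTGAN output):
#   indices 0-1: one continuous column occupying 2 dims
#   indices 2-4: discrete column A, one-hot over 3 categories (category 2 active)
#   indices 5-6: discrete column B, one-hot over 2 categories (category 0 active)
row = np.array([0.3, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0])

n_category = [3, 2]  # number of categories per discrete column


def pick_category(row, matrix_st, n_category, col_idx):
    """Return the argmax of the slice that is assumed to hold column col_idx."""
    st = matrix_st[col_idx]
    ed = st + n_category[col_idx]
    return int(np.argmax(row[st:ed]))


stale_st = [0, 0]    # offsets that were never updated
tracked_st = [2, 5]  # offsets recorded as start positions in the row

stale_picks = [pick_category(row, stale_st, n_category, i) for i in range(2)]
tracked_picks = [pick_category(row, tracked_st, n_category, i) for i in range(2)]

print(stale_picks)    # slices always start at index 0 and read the continuous dims
print(tracked_picks)  # slices cover each discrete column's own one-hot block
```

With the stale offsets, both lookups slice from index 0 and report a category taken from the continuous values rather than from the one-hot block of the column in question.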
I agree that this needs attention. Besides sample_original_condvec, generate_cond_from_condition_column_info is also affected. In the current state, the condition vectors created are not correct. Thanks @tejuafonja for suggesting a fix!
The fix seems to have an issue with mixed tables. Your code also counts continuous columns, which is not correct. A simple expansion to two counters (one overall and one for discrete columns) seems to do the trick:
st = 0
discrete_st = 0
current_id = 0
current_cond_st = 0
for column_info in output_info:
    if is_discrete_column(column_info):
        span_info = column_info[0]
        ed = st + span_info.dim
        discrete_ed = discrete_st + span_info.dim
        category_freq = np.sum(data[:, st:ed], axis=0)
        if log_frequency:
            category_freq = np.log(category_freq + 1)
        category_prob = category_freq / np.sum(category_freq)
        self._discrete_column_category_prob[current_id, :span_info.dim] = category_prob
        self._discrete_column_cond_st[current_id] = current_cond_st
        self._discrete_column_n_category[current_id] = span_info.dim
        # **updated self._discrete_column_matrix_st**
        self._discrete_column_matrix_st[current_id] = discrete_st
        current_cond_st += span_info.dim
        current_id += 1
        st = ed
        discrete_st = discrete_ed
    else:
        st += sum([span_info.dim for span_info in column_info])
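The two proposals in this thread differ only in which counter they store. The sketch below computes both counters side by side for a small mixed table. SpanInfo, is_discrete_column, and column_offsets here are simplified stand-ins written for this example (assumptions, not the library's definitions), and the layout in output_info is made up.

```python
from collections import namedtuple

# Stand-in for the library's span descriptor (assumed shape: dim + activation).
SpanInfo = namedtuple("SpanInfo", ["dim", "activation_fn"])


def is_discrete_column(column_info):
    # Assumption for this sketch: a discrete column is a single softmax span.
    return len(column_info) == 1 and column_info[0].activation_fn == "softmax"


def column_offsets(output_info):
    """Return (overall, discrete_only) start offsets for each discrete column."""
    st = 0           # advances over every column, continuous or discrete
    discrete_st = 0  # advances over discrete columns only
    overall, discrete_only = [], []
    for column_info in output_info:
        if is_discrete_column(column_info):
            dim = column_info[0].dim
            overall.append(st)
            discrete_only.append(discrete_st)
            st += dim
            discrete_st += dim
        else:
            st += sum(span_info.dim for span_info in column_info)
    return overall, discrete_only


# Hypothetical mixed table: continuous, discrete(3), continuous, discrete(2).
output_info = [
    [SpanInfo(1, "tanh"), SpanInfo(3, "softmax")],  # continuous, 4 dims total
    [SpanInfo(3, "softmax")],                       # discrete, 3 categories
    [SpanInfo(1, "tanh"), SpanInfo(2, "softmax")],  # continuous, 3 dims total
    [SpanInfo(2, "softmax")],                       # discrete, 2 categories
]

overall, discrete_only = column_offsets(output_info)
print(overall)        # positions counted across all columns
print(discrete_only)  # positions counted across discrete columns only
```

The overall counter is the one that lines up with slices of the full data matrix such as data[:, st:ed]. In the loop above, discrete_st is incremented exactly where current_cond_st is, so the discrete-only counter appears to duplicate the value already stored in _discrete_column_cond_st.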
Thanks for improving the suggested fix!
Hi everyone, I believe that the initial solution from @tejuafonja is correct. See #236 for more details on this, and feel free to give some feedback there!