
When using `CTGANSynthesizer`, returned loss values should be plain floats rather than PyTorch objects

Open · srinify opened this issue 1 year ago · 0 comments

Problem Description

After fitting a model with `CTGANSynthesizer`, accessing `loss_values` returns a DataFrame in which the loss values are PyTorch tensor objects rather than plain float values.

[Screenshot 2024-02-22 at 4:20:36 PM]

As a result, plotting these values requires an extra step to extract the numbers with `apply()`, which I feel adds unnecessary friction.

[Screenshot 2024-02-22 at 4:58:56 PM]
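Until the values come back as floats, the `apply()` workaround can be wrapped in a small helper. This is a minimal sketch, not part of SDV or CTGAN: the names `to_float` and `loss_values_as_floats` are hypothetical, and it assumes the loss columns are named 'Generator Loss' and 'Discriminator Loss' (as in the plotting snippet in this issue) and that each cell is either a one-element tensor or anything else exposing `.item()`, or a plain number.

```python
import pandas as pd


def to_float(value):
    """Return a plain Python float for a scalar tensor-like value or a number."""
    # One-element torch.Tensor objects and NumPy scalars both expose .item(),
    # which returns the underlying Python number; plain numbers go straight to float().
    if hasattr(value, "item"):
        return float(value.item())
    return float(value)


def loss_values_as_floats(loss_df, columns=("Generator Loss", "Discriminator Loss")):
    """Return a copy of loss_df with the given loss columns converted to floats."""
    out = loss_df.copy()  # leave the caller's DataFrame untouched
    for col in columns:
        if col in out.columns:  # silently skip columns that are not present
            out[col] = out[col].apply(to_float)
    return out
```

With this helper, the plot from the expected-behavior section becomes `loss_values_as_floats(ctgan.loss_values).plot(x='Epoch', y=['Generator Loss', 'Discriminator Loss'])`, at the cost of one extra call the issue is asking to eliminate.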

Expected behavior

Ideally, the returned DataFrame would contain plain float values for the Generator and Discriminator losses. The loss values could then be plotted directly:

```python
# `ctgan` is a fitted model whose loss_values DataFrame holds float losses per epoch
loss_df = ctgan.loss_values
loss_df.plot(x='Epoch', y=['Generator Loss', 'Discriminator Loss'])
```

Additional context

The `CTGANSynthesizer` class passes the loss values through unchanged from the underlying PyTorch model: https://github.com/sdv-dev/CTGAN/blob/main/ctgan/synthesizers/ctgan.py#L426
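On the library side, one way to address this is to convert each loss to a Python float at the moment it is recorded, rather than after training. The sketch below is hypothetical and is not the actual CTGAN code at the linked line: `record_epoch_loss` and its arguments are illustrative names, and it assumes the generator and discriminator losses arrive as one-element tensors (or plain numbers).

```python
import pandas as pd


def _scalar(value):
    # .item() turns a one-element tensor (CPU or GPU) into a Python number,
    # so the stored value is no longer tied to the autograd graph or the device.
    return float(value.item()) if hasattr(value, "item") else float(value)


def record_epoch_loss(loss_df, epoch, generator_loss, discriminator_loss):
    """Append one epoch's losses as plain floats (hypothetical helper)."""
    row = pd.DataFrame({
        "Epoch": [epoch],
        "Generator Loss": [_scalar(generator_loss)],
        "Discriminator Loss": [_scalar(discriminator_loss)],
    })
    # Start a fresh frame on the first epoch instead of concatenating onto an empty one.
    if loss_df is None or loss_df.empty:
        return row
    return pd.concat([loss_df, row], ignore_index=True)
```

Converting at record time means every consumer of `loss_values`, including `CTGANSynthesizer`, receives float columns without any post-processing.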

srinify · Feb 23 '24 15:02