SDV
When using `CTGANSynthesizer`, `loss_values` should contain floats rather than PyTorch tensors
Problem Description
After fitting a model with `CTGANSynthesizer`, accessing `loss_values` returns a DataFrame in which the loss values are PyTorch tensor objects instead of plain float values.
As a result, plotting them requires an extra step to extract the numeric values with `apply()`, which I feel adds unnecessary friction.
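Until that changes, the extra `apply()` step can be wrapped in a small helper. This is a minimal sketch, not part of SDV or CTGAN: `losses_to_float` is a hypothetical name, and the column names are assumed from the plotting snippet in this issue. It relies on the fact that zero-dimensional PyTorch tensors (like NumPy scalars) expose `.item()`, which returns a plain Python number.

```python
import pandas as pd

# Hypothetical helper (not an SDV/CTGAN API). Column names are assumed
# from the plotting example in this issue.
LOSS_COLUMNS = ['Generator Loss', 'Discriminator Loss']


def losses_to_float(loss_df: pd.DataFrame, columns=LOSS_COLUMNS) -> pd.DataFrame:
    """Return a copy of loss_df whose loss columns hold plain floats."""
    out = loss_df.copy()  # leave the original DataFrame untouched
    for col in columns:
        # Scalar tensors and NumPy scalars have .item(); anything else
        # (e.g. an already-numeric value) is passed straight to float().
        out[col] = out[col].apply(
            lambda v: float(v.item()) if hasattr(v, 'item') else float(v)
        )
    return out
```

With a fitted model this would be used as `losses_to_float(ctgan.loss_values).plot(x='Epoch', y=['Generator Loss', 'Discriminator Loss'])`. Copying first keeps the synthesizer's own attribute unchanged.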
Expected behavior
Ideally, the returned DataFrame would contain plain float values in the Generator Loss and Discriminator Loss columns, which would make the losses straightforward to plot:
```python
loss_df = ctgan.loss_values  # `ctgan` is an already-fitted model
loss_df.plot(x='Epoch', y=['Generator Loss', 'Discriminator Loss'])
```
Additional context
`CTGANSynthesizer` simply passes through the loss values from the underlying PyTorch model: https://github.com/sdv-dev/CTGAN/blob/main/ctgan/synthesizers/ctgan.py#L426
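One way to address this at the source would be to convert each loss to a float before it is stored. The sketch below is hypothetical and is not the actual CTGAN code: `append_epoch_losses` and `build_loss_frame` are illustrative names, and it assumes the per-epoch generator and discriminator losses are zero-dimensional tensors, as the linked line suggests. Calling `.item()` on such a tensor returns a Python number.

```python
import pandas as pd

# Hypothetical producer-side sketch; NOT the real CTGAN implementation.
LOSS_FRAME_COLUMNS = ['Epoch', 'Generator Loss', 'Discriminator Loss']


def append_epoch_losses(rows: list, epoch: int, loss_g, loss_d) -> None:
    """Record one epoch's losses as plain floats."""
    # .item() converts a zero-dimensional tensor to a Python number,
    # so no tensor objects end up in the resulting DataFrame.
    rows.append({
        'Epoch': epoch,
        'Generator Loss': float(loss_g.item()),
        'Discriminator Loss': float(loss_d.item()),
    })


def build_loss_frame(rows: list) -> pd.DataFrame:
    """Build the loss DataFrame once from the collected per-epoch rows."""
    return pd.DataFrame(rows, columns=LOSS_FRAME_COLUMNS)
```

Collecting plain dicts and building the DataFrame once also avoids repeatedly concatenating one-row DataFrames inside the training loop.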