SDV

Access to underlying model weights of CTGAN and TVAE for manipulation

Open · MakGulati opened this issue 4 years ago · 3 comments

Problem Description

The ability to access the underlying deep learning model, which is written in PyTorch. It would be nice to have an API call to read and write the network weights (e.g. the generator and discriminator networks in the case of CTGAN).
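For reference, PyTorch already exposes read/write access to a module's weights through state_dict() and load_state_dict(). The sketch below shows what such a pair of calls might look like. read_weights, write_weights and make_network are hypothetical names, not part of SDV or ctgan, and the small nn.Sequential is only a stand-in for a real CTGAN generator, whose layer sizes depend on the training data.

```python
import copy

import torch
from torch import nn


def make_network() -> nn.Module:
    # Hypothetical stand-in for a CTGAN generator or discriminator.
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))


def read_weights(module: nn.Module) -> dict:
    """Return an independent copy of the module's parameters and buffers."""
    # state_dict() tensors share storage with the module, so deep-copy them.
    return copy.deepcopy(module.state_dict())


def write_weights(module: nn.Module, weights: dict) -> None:
    """Overwrite the module's weights in place; names and shapes must match."""
    module.load_state_dict(weights)


source = make_network()
target = make_network()  # starts from a different random initialization

# Copy the weights of one network into another with the same architecture.
write_weights(target, read_weights(source))
```

Because load_state_dict checks parameter names and shapes, this only works when both networks were built for data with the same dimensions.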

Expected behavior

API access to the model variables of the CTGAN package through the SDV package.

Additional context

Having such a feature would make it possible to train federated learning models. After the deep learning networks are trained with SDV's fit method on different local clients, each holding its own local dataset, the model weights could be read through this new API and sent to a global server for aggregation. The global server then sends back the aggregated model, which is loaded on the selected clients at each round. Local training then starts again, and the process repeats. Finally, after training for the desired number of rounds, SDV's sample method can be used to generate synthetic datasets from the aggregated deep learning model, which captures behaviors from the different clients' local datasets.
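The issue does not say which aggregation rule the server would use; one common choice is federated averaging (FedAvg), where each parameter is averaged across clients, weighted by the size of each client's local dataset. The following is a minimal sketch of that aggregation step only, assuming the weights have already been read out as flat lists of floats keyed by parameter name (in practice these would be tensors from each client's state_dict). federated_average is a hypothetical helper, not an SDV or ctgan function.

```python
from typing import Dict, List

# Parameter name -> flattened parameter values (assumed representation).
Weights = Dict[str, List[float]]


def federated_average(client_weights: List[Weights], num_samples: List[int]) -> Weights:
    """FedAvg-style aggregation: per-parameter mean weighted by client dataset size."""
    if not client_weights or len(client_weights) != len(num_samples):
        raise ValueError("need at least one client and one sample count per client")
    total = sum(num_samples)
    averaged: Weights = {}
    for name in client_weights[0]:
        length = len(client_weights[0][name])
        averaged[name] = [
            sum(w[name][i] * n for w, n in zip(client_weights, num_samples)) / total
            for i in range(length)
        ]
    return averaged


# One round: two clients report their weights after local training.
clients = [
    {"fc.weight": [1.0, 2.0], "fc.bias": [0.0]},  # client with 1 row of data
    {"fc.weight": [3.0, 4.0], "fc.bias": [1.0]},  # client with 3 rows of data
]
global_weights = federated_average(clients, num_samples=[1, 3])
# global_weights == {"fc.weight": [2.5, 3.5], "fc.bias": [0.75]}
```

The server would then load global_weights back into each selected client's networks before the next round of local training.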

MakGulati, Jan 12 '22 07:01

Thank you for filing & describing your use case @MakGulati. We'll keep this issue open and update it whenever we make progress.

An unsupported & hacky workaround you can try in the meantime: you can access the generator through model._generator. This gives you a PyTorch Module object. From there, you'd have to refer to the PyTorch user guides & API to extract any parameters you need.
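As an illustration of the kind of PyTorch calls this workaround leads to, the sketch below inspects a module's learnable parameters. It uses a small nn.Sequential as a hypothetical stand-in, because the real generator only exists after fitting and its layer sizes depend on the data. Since model._generator is a private attribute, any code built on it may break between releases.

```python
import torch
from torch import nn

# Hypothetical stand-in for the fitted generator (model._generator in the comment above).
generator = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 5))

# Map each learnable tensor's name to its shape; ReLU has no parameters.
shapes = {name: tuple(param.shape) for name, param in generator.named_parameters()}
# {"0.weight": (32, 10), "0.bias": (32,), "2.weight": (5, 32), "2.bias": (5,)}

# Total number of learnable values across all layers.
num_params = sum(param.numel() for param in generator.parameters())
# 320 + 32 + 160 + 5 = 517

# Raw tensors can be read without tracking gradients, e.g. to send them elsewhere.
with torch.no_grad():
    first_layer_weight = generator[0].weight.clone()
```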

npatki, Jan 12 '22 21:01

I would also need access to the discriminator. Since it is currently a local variable, I cannot access it outside the class. @npatki With changes in the discriminator's dimensions, the discriminator also has to be reloaded.

MakGulati, Jan 19 '22 05:01

Hi everyone, would assigning the discriminator to self, like self._generator, be a potential solution here? Maybe it could be optional, with False as the default?
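A toy sketch of this proposal might look like the following. ToySynthesizer and the keep_discriminator flag are hypothetical names, not ctgan's actual code, and single nn.Linear layers stand in for the real networks. Reusing a stored discriminator on the next fit call when its input size still matches is an added assumption, included to show how a reloaded discriminator could carry over between federated rounds.

```python
from typing import Optional

from torch import nn


class ToySynthesizer:
    """Hypothetical illustration of optionally keeping the discriminator on self."""

    def __init__(self, keep_discriminator: bool = False):
        # Opt-in flag, False by default so existing behaviour is unchanged.
        self.keep_discriminator = keep_discriminator
        self._generator: Optional[nn.Module] = None
        self._discriminator: Optional[nn.Module] = None

    def fit(self, data_dim: int) -> None:
        # The generator is already stored on self.
        self._generator = nn.Linear(data_dim, data_dim)

        # Assumption: reuse a stored discriminator if its input size still matches.
        stored = self._discriminator
        if stored is not None and stored.in_features == data_dim:
            discriminator = stored
        else:
            discriminator = nn.Linear(data_dim, 1)

        # ... the adversarial training loop would go here ...

        # Otherwise the discriminator is discarded when fit() returns.
        if self.keep_discriminator:
            self._discriminator = discriminator


default = ToySynthesizer()
default.fit(data_dim=4)  # default._discriminator stays None

federated = ToySynthesizer(keep_discriminator=True)
federated.fit(data_dim=4)  # federated._discriminator is now an nn.Module
```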

AndresAlgaba, Jul 27 '22 13:07