Access to underlying model weights of CTGAN and TVAE for manipulation
Problem Description
I would like to be able to access the underlying deep learning model, which is written in PyTorch. It would be nice to have an API function call to read and write the network weights (e.g. the generator and discriminator networks in the case of CTGAN).
Expected behavior
API access to the variables of the CTGAN package through the SDV package.
Additional context
Having such a feature would make it possible to train federated learning models. After the deep learning networks are trained with SDV's fit method on different local clients, each holding its own local dataset, the model weights could be read (through this new API) and sent to a global server for aggregation. At each round, the global server sends the aggregated model back, and it is loaded on the selected clients. Local training then starts again and the process repeats. Finally, after training for the desired number of rounds, SDV's sample method can be used to generate synthetic datasets from the aggregated deep learning model, which captures the behavior of the datasets held by the different local clients.
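One common aggregation rule for this kind of round-based process is federated averaging (FedAvg): the server averages each named parameter across clients, weighted by how many rows each client trained on. The sketch below assumes the weights are exported as a name-to-value mapping; a PyTorch `state_dict()` is such a mapping, and its tensors support the `*` and `+` used here. `federated_average`, `fl_round` and `toy_local_train` are hypothetical names, and `toy_local_train` is only a stand-in for local training with SDV's fit method. The sketch does not assume that fit() can resume from loaded weights.

```python
from typing import Any, Callable, Dict, List, Mapping, Sequence


def federated_average(
    client_states: Sequence[Mapping[str, Any]],
    client_sizes: Sequence[int],
) -> Dict[str, Any]:
    """Average each named parameter across clients, weighted by local row count."""
    if not client_states or len(client_states) != len(client_sizes):
        raise ValueError("need at least one client and one size per client")
    total = sum(client_sizes)
    if total <= 0:
        raise ValueError("client sizes must sum to a positive number")
    names = list(client_states[0])
    for state in client_states[1:]:
        if set(state) != set(names):
            raise ValueError("all clients must expose the same parameter names")

    averaged: Dict[str, Any] = {}
    for name in names:
        weighted_sum = None
        for state, size in zip(client_states, client_sizes):
            term = state[name] * (size / total)  # works for floats and tensors
            weighted_sum = term if weighted_sum is None else weighted_sum + term
        averaged[name] = weighted_sum
    return averaged


def fl_round(
    global_state: Mapping[str, Any],
    clients: Sequence[List[float]],
    local_train: Callable[[Dict[str, Any], List[float]], Dict[str, Any]],
) -> Dict[str, Any]:
    """One round: load the global weights on each client, train locally, aggregate."""
    states, sizes = [], []
    for data in clients:
        # Shallow copy: enough here because toy_local_train returns new values.
        local_state = dict(global_state)
        states.append(local_train(local_state, data))
        sizes.append(len(data))
    return federated_average(states, sizes)


def toy_local_train(state: Dict[str, Any], data: List[float]) -> Dict[str, Any]:
    """Hypothetical stand-in for local fitting: move each weight halfway to the data mean."""
    target = sum(data) / len(data)
    return {name: value + 0.5 * (target - value) for name, value in state.items()}


if __name__ == "__main__":
    state = {"w": 0.0, "b": 0.0}
    clients = [[2.0, 2.0], [4.0, 4.0, 4.0, 4.0]]
    for round_number in range(3):
        state = fl_round(state, clients, toy_local_train)
        print(round_number, state)
```

Weighting by row count means a client with more data pulls the global model further toward its own weights; an unweighted mean would treat every client equally regardless of dataset size.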
Thank you for filing & describing your use case @MakGulati. We'll keep this issue open and update it whenever we make progress.
An unsupported & hacky workaround you can try in the meantime: you can access the generator model using model._generator. This will get you a PyTorch Module object. From there on out, you'd have to refer to the PyTorch user guides & API to extract any desired parameters.
I would also need access to the discriminator. Because it is currently a local variable, I cannot access it outside the class. @npatki Since the dimensions of the discriminator can change, the discriminator would also need to be reloaded.
Hi everyone, would assigning the discriminator to self, like self._generator, be a potential solution here? Maybe it could be optional, with False as the default?
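The opt-in idea could look roughly like the following sketch. `SketchSynthesizer`, `keep_discriminator` and `_build_network` are hypothetical names that only mimic the shape of the code under discussion: a generator kept on `self._generator` and a discriminator that is otherwise a local variable inside `fit()`. The "networks" are placeholder dicts rather than PyTorch modules, and their width is derived from the data, echoing the point that the discriminator's dimensions can change.

```python
from typing import Any, Dict, List, Optional


class SketchSynthesizer:
    """Hypothetical stand-in for a GAN synthesizer; not CTGAN's actual class."""

    def __init__(self, keep_discriminator: bool = False) -> None:
        self.keep_discriminator = keep_discriminator
        self._generator: Optional[Dict[str, Any]] = None
        self._discriminator: Optional[Dict[str, Any]] = None  # None unless opted in

    @staticmethod
    def _build_network(name: str, width: int) -> Dict[str, Any]:
        # Placeholder for constructing a torch.nn.Module of the right size.
        return {"name": name, "width": width}

    def fit(self, rows: List[List[float]]) -> "SketchSynthesizer":
        width = len(rows[0])  # depends on the data, so it can change between fits
        generator = self._build_network("generator", width)
        discriminator = self._build_network("discriminator", width)
        # ... the adversarial training loop would update both networks here ...
        self._generator = generator
        if self.keep_discriminator:
            # Only keep a reference to the discriminator when the caller asked for it.
            self._discriminator = discriminator
        return self
```

Defaulting the flag to False keeps the current behavior for existing users; presumably it would also keep saved models from growing, since a stored discriminator would be pickled along with the synthesizer.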