PlotNeuralNet
Feature Request: Automated Creation Based On PyTorch/Keras Neural Network Structures
This feature request is about an automated creation of visualizations from PyTorch/Keras sequential architectures.
Current State
Neural network structures have to be created explicitly for visualization, using the functions PlotNeuralNet supports, which mainly cover parts of residual convolutional architectures.
Proposed State
Neural network structures are visualized based on a given Python class structure, like PyTorch or Keras modules.
Possible Implementation
Proposing a module-parsing Python class that retrieves a "net list".
Initial implementation should focus on the simple feed forward network structures. This can be extended based on more complex net lists.
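As a rough illustration of what such a "net list" could look like for a simple feed-forward model, here is a minimal sketch. The names `LayerSpec` and `extract_net_list` are hypothetical and not part of PlotNeuralNet or of the PR; the only PyTorch behaviour assumed is that a module exposes `named_children()` and that fully connected layers carry `in_features` and `out_features`.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class LayerSpec:
    """One entry of the hypothetical "net list"."""
    name: str                      # child name, e.g. "0" inside a sequential model
    kind: str                      # layer class name, e.g. "Linear" or "ReLU"
    in_features: Optional[int]     # None for layers without this attribute
    out_features: Optional[int]    # None for layers without this attribute


def extract_net_list(model: Any) -> List[LayerSpec]:
    """Collect the direct children of a sequential model in definition order.

    Only `named_children()` is used, so torch itself is not imported here;
    any object offering that method (such as a torch.nn.Module) works.
    """
    net_list = []
    for name, layer in model.named_children():
        net_list.append(
            LayerSpec(
                name=name,
                kind=type(layer).__name__,
                in_features=getattr(layer, "in_features", None),
                out_features=getattr(layer, "out_features", None),
            )
        )
    return net_list
```

With PyTorch installed, passing a `torch.nn.Sequential` built from a `Linear`, a `ReLU` and another `Linear` layer should give three entries, with the feature sizes filled in for the two `Linear` layers and left as `None` for the activation. A drawing backend could then walk this list and emit one PlotNeuralNet block per entry.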
Unfortunately, structures that are not defined in a feed-forward or otherwise strictly structured layout, but are instead built by calling layers inside a forward pass function/method, are much harder to support. I am asking for help here.
There is already raw code for PyTorch torch.nn.Sequential, which I am cleaning up for an initial PR from my fork. Please wait a bit, then have a look at the proposed PR and its possible extensions.
This will initially be an easy example with explicit fully connected layers.
PR #126 focuses on PyTorch.
TODOs for subsequent PRs
- Variable generation with respect to PyTorch layers and activation functions
- Keras support