PlotNeuralNet
Feature: Automated Creation Based on Example for PyTorch Linear Modules with ReLU Activations
Work in progress
This PR addresses #124.
Automated generation from a PyTorch module class (a child of `torch.nn.Module`), leveraging the `torchinfo` architecture summary interface, which is comparable to the TensorFlow/Keras `summary` method.
This is created from the following code:
```python
# Define example module
import torch as th


class MLP(th.nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.net = th.nn.Sequential(
            th.nn.Linear(2, 16),
            th.nn.ReLU(),
            th.nn.Linear(16, 16),
            th.nn.ReLU(),
            th.nn.Linear(16, 1)
        )

    def forward(self, x):
        x = self.net(x)
        return x


# Parse the example module
from pycore.torchparse import TorchArchParser
from pycore.tikzeng import to_generate

mlp = MLP()
# input_size=(1, 2): a batch of one sample with two input features
parser = TorchArchParser(torch_module=mlp, input_size=(1, 2))
arch = parser.get_arch()
to_generate(arch, pathname="./test_torch_mlp.tex")
```
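For reference, the shapes and parameter counts that a `torchinfo` summary of this `MLP` should report for `input_size=(1,2)` can be checked by hand. The snippet below is a plain-Python sketch (no PyTorch needed) that roughly mirrors the "Output Shape" and "Param #" columns for the three `Linear` layers; the `ReLU` layers add no parameters and keep the shape of their input. The layer list and the `linear_params` helper are illustrative only and are not part of `TorchArchParser`.

```python
# Hand-written view of the example MLP: (in_features, out_features) for each
# nn.Linear, in the order they appear in the Sequential container.
LINEAR_LAYERS = [(2, 16), (16, 16), (16, 1)]
BATCH = 1  # first entry of input_size=(1, 2)


def linear_params(in_features: int, out_features: int) -> int:
    """Weight matrix (out x in) plus one bias per output unit."""
    return in_features * out_features + out_features


rows = []
for i, (n_in, n_out) in enumerate(LINEAR_LAYERS, start=1):
    # Output of a Linear layer keeps the batch dimension and has n_out features.
    rows.append((f"Linear-{i}", [BATCH, n_out], linear_params(n_in, n_out)))

for name, shape, params in rows:
    print(f"{name:<10} {str(shape):<10} {params:>5}")
print("Total params:", sum(p for _, _, p in rows))
```

With 2·16 + 16 = 48, 16·16 + 16 = 272 and 16·1 + 1 = 17 parameters per layer, the total is 337, which is what the generated diagram's layer sizes should be consistent with.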
- [x] Initial PyTorch support with Linear and ReLU in Sequential module structure
- [x] Added an interface for custom fill colors; reuses (hacks) `Conv` for the fully connected layers until a specialized function is provided.
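A dedicated fully connected block could reuse the same `Box` pic that `to_Conv` emits, only with the fill colour exposed as a parameter and the height scaled by the number of units. The following is a hypothetical sketch: `to_fc`, its height-scaling rule and the `\FcColor` default are assumptions, and whatever colour macro is passed as `fill` must be defined in the LaTeX preamble (for example via `to_cor()` or a custom `\def`).

```python
def to_fc(name: str, n_units: int, fill: str = r"\FcColor",
          offset: str = "(0,0,0)", to: str = "(0,0,0)",
          caption: str = " ", max_height: float = 40.0,
          max_units: int = 16) -> str:
    """Hypothetical fully connected block.

    Emits the same \\pic ... {Box={...}} structure as to_Conv, but with a
    caller-chosen fill colour, width/depth fixed to 1 (a thin slab) and a
    height proportional to the number of units (at least 1).
    """
    height = max(1.0, max_height * n_units / max_units)
    return (
        f"\\pic[shift={{{offset}}}] at {to}\n"
        f"    {{Box={{\n"
        f"        name={name},\n"
        f"        caption={caption},\n"
        f"        xlabel={{{{{n_units}, }}}},\n"  # renders as xlabel={{16, }}
        f"        fill={fill},\n"
        f"        height={height:g},\n"
        f"        width=1,\n"
        f"        depth=1\n"
        f"        }}\n"
        f"    }};\n"
    )


print(to_fc("fc1", 16, caption="FC1"))
print(to_fc("fc3", 1, fill=r"\FcReluColor", offset="(2,0,0)",
            to="(fc1-east)", caption="FC3"))
```

Because the returned value is a plain string, such blocks can be appended to the same `arch` list that is passed to `to_generate`.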
TODOs for subsequent PRs
- ~~Variable generation with respect to PyTorch layers and activation functions~~
- Keras support
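The layer- and activation-dependent generation in the TODO list essentially means walking the layers in forward order, turning each `Linear` into a block and letting a following activation change that block's appearance. A minimal sketch under those assumptions: the input is a hand-written list of `(class_name, out_features)` pairs rather than real `torchinfo` output, and `plan_blocks`, `FILL_BY_ACTIVATION` and the colour macro names are hypothetical.

```python
from typing import Dict, List, Optional, Sequence, Tuple

# Hypothetical per-activation fill colours (PlotNeuralNet-style colour macros).
FILL_BY_ACTIVATION: Dict[Optional[str], str] = {
    None: r"\FcColor",
    "ReLU": r"\FcReluColor",
}
ACTIVATIONS = {"ReLU", "Sigmoid", "Tanh"}


def plan_blocks(layers: Sequence[Tuple[str, int]]) -> List[Tuple[str, int, str]]:
    """Turn (class_name, out_features) pairs, in forward order, into
    (block_name, width, fill) triples. Linear layers become blocks; an
    activation only changes the fill of the block before it."""
    blocks: List[list] = []
    for class_name, width in layers:
        if class_name == "Linear":
            blocks.append([f"fc{len(blocks) + 1}", width, None])
        elif class_name in ACTIVATIONS:
            if not blocks:
                raise ValueError(f"{class_name} appears before any Linear layer")
            blocks[-1][2] = class_name
        else:
            raise NotImplementedError(f"no drawing rule for {class_name}")
    # Activations without a dedicated colour fall back to the plain FC colour.
    return [
        (name, width, FILL_BY_ACTIVATION.get(act, r"\FcColor"))
        for name, width, act in blocks
    ]


# The example MLP; activations carry no width, so 0 is only a placeholder.
mlp = [("Linear", 16), ("ReLU", 0), ("Linear", 16), ("ReLU", 0), ("Linear", 1)]
for block in plan_blocks(mlp):
    print(block)
```

Keeping the activation as an attribute of the preceding block, rather than as a block of its own, is in line with how PlotNeuralNet typically draws activations as colour bands on a layer instead of as separate shapes.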
Addresses #124 with respect to PyTorch.
Please feel free to give feedback and contribute to this feature, especially regarding the subsequent Keras support.
Ready as initial functionality for automated generation from PyTorch. Please review and merge if deemed OK.
Hey guys, any plans to support conv layers?
Hey @space192, I am currently too busy to push the PR further, but we would happily accept your extension to CNNs. You can open a PR against this branch of my fork so it shows up here.