flax
Visualize the model/params structure
Problem you have encountered:
Currently it is not clear how to inspect a model's structure, for example to see whether an imported model uses dropout or batch norm, or to find out which weights to freeze and which to fine-tune.
Similarly, the params returned by model.init are difficult to inspect, as extracting the structure requires writing custom code.
What you expected to happen:
It would be nice if __repr__ displayed a human-readable structure, as in PyTorch:
- For params, the shape/dtype, rather than 1000+ lines of weight values:
>>> params
FrozenDict({
    params: {
        Encoder_0: {
            Conv_0: {
                kernel: float32[3, 3, 1, 16],
                bias: float32[16],
            },
            Conv_1: {
                kernel: float32[3, 3, 16, 32],
                bias: float32[32],
            },
            Conv_2: {
                kernel: float32[7, 7, 32, 64],
                bias: float32[64],
            },
        },
        Decoder_0: {
            ConvTranspose_0: {
                kernel: float32[7, 7, 64, 32],
                bias: float32[32],
            },
            ConvTranspose_1: {
                kernel: float32[3, 3, 32, 16],
                bias: float32[16],
            },
            ConvTranspose_2: {
                kernel: float32[3, 3, 16, 1],
                bias: float32[1],
            },
        },
    },
})
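For illustration, output in roughly this shape can be produced today with a few lines of custom code. The following is only a minimal sketch, not a Flax API: the helper name summarize is hypothetical, and it assumes the params are a nested mapping (a plain dict here; FrozenDict also behaves as a mapping) whose leaves expose .shape and .dtype.

```python
from collections.abc import Mapping

import numpy as np


def summarize(tree, indent=0):
    """Return a string with dtype[shape] for every array leaf of a nested mapping.

    `summarize` is a hypothetical helper written for this example only.
    """
    pad = "    " * indent
    lines = []
    for name, value in tree.items():
        if isinstance(value, Mapping):
            # Recurse into sub-collections / submodules.
            lines.append(f"{pad}{name}: {{")
            lines.append(summarize(value, indent + 1))
            lines.append(f"{pad}}},")
        else:
            # Leaf: print only dtype and shape, never the values.
            lines.append(f"{pad}{name}: {value.dtype}{list(value.shape)},")
    return "\n".join(lines)


# Stand-in for the output of model.init, built by hand so the example
# does not depend on any particular model.
params = {
    "params": {
        "Conv_0": {
            "kernel": np.zeros((3, 3, 1, 16), dtype=np.float32),
            "bias": np.zeros((16,), dtype=np.float32),
        },
    },
}

print(summarize(params))
```

The same function could be called on the variables returned by model.init, assuming they are a nested mapping of arrays.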
- For the model, the module names and submodules:
PyTorch, for example, displays the model structure quite clearly, so it is easy to see which operations are used:
import torchvision.models.resnet as resnet
model = resnet.resnet18()
print(model)
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
I think it makes sense for Module.__repr__ to print the variable structure (name + shape + dtype). This shouldn't be very difficult to add. The same can be done for FrozenDict.__repr__. We could consider a threshold like onp.size(x) > 5 to switch between printing the shape and the full content.
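To make the threshold idea concrete, here is a rough sketch of a per-leaf formatter. The name leaf_repr and the default of 5 are assumptions taken from the suggestion above, not an existing Flax function; it uses NumPy directly rather than the onp alias.

```python
import numpy as np


def leaf_repr(x, threshold=5):
    """Format one array leaf: full values if small, dtype[shape] otherwise.

    Hypothetical helper; the threshold of 5 mirrors the suggestion above.
    """
    x = np.asarray(x)
    if x.size > threshold:
        # Large array: show only the structure.
        return f"{x.dtype}{list(x.shape)}"
    # Small array: showing the actual values is still readable.
    return repr(x.tolist())


small = np.zeros((2,), dtype=np.float32)
large = np.zeros((3, 3), dtype=np.float32)

print(leaf_repr(small))
print(leaf_repr(large))
```

A __repr__ built this way would keep scalars and tiny vectors visible while collapsing kernels and other large arrays to a one-line summary.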
Haiku has excellent model summary functionality (documentation here). I think equivalents to haiku.experimental.tabulate and haiku.experimental.eval_summary would be very helpful.
You can use parameter_overview in clu for params visualization:
import jax
import numpy as np
from flax import linen as nn
from clu import parameter_overview

class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

key = jax.random.PRNGKey(0)
variables = CNN().init(key, np.random.randn(1, 32, 32, 3))
print(parameter_overview.get_parameter_overview(variables))
+-----------------------+----------------+-----------+-----------+--------+
| Name | Shape | Size | Mean | Std |
+-----------------------+----------------+-----------+-----------+--------+
| params/Conv_0/bias | (32,) | 32 | 0.0 | 0.0 |
| params/Conv_0/kernel | (3, 3, 3, 32) | 864 | 0.00277 | 0.2 |
| params/Conv_1/bias | (64,) | 64 | 0.0 | 0.0 |
| params/Conv_1/kernel | (3, 3, 32, 64) | 18,432 | 0.000202 | 0.0591 |
| params/Dense_0/bias | (256,) | 256 | 0.0 | 0.0 |
| params/Dense_0/kernel | (4096, 256) | 1,048,576 | -1.54e-05 | 0.0156 |
| params/Dense_1/bias | (10,) | 10 | 0.0 | 0.0 |
| params/Dense_1/kernel | (256, 10) | 2,560 | -0.00159 | 0.0622 |
+-----------------------+----------------+-----------+-----------+--------+
Total: 1,070,794
I think a native equivalent to the Haiku functionality that @n2cholas showed would be very nice!
Assigning this to @cgarciae since he is working on this currently.
+1
Closing this since @cgarciae has written Module.tabulate, please check it out!
Is there a way to visualize the model like http://alexlenail.me/NN-SVG/AlexNet.html?