
Visualize the model/params structure

Conchylicultor opened this issue 3 years ago • 6 comments

Problem you have encountered:

Currently it is not clear how to inspect a model's structure, for example to see whether an imported model uses dropout, batch norm, etc., or to find out which weights we want to freeze vs. fine-tune. Similarly, the params returned by model.init are difficult to inspect, as extracting the structure requires writing custom code.

What you expected to happen:

It would be nice if __repr__ displayed a human-readable structure, as in PyTorch:

  1. For params, the shape/dtype rather than 1000+ lines of weight values (a workaround sketch follows after the PyTorch example below):
>>> params
FrozenDict({
    params: {
        Encoder_0: {
            Conv_0: {
                kernel: float32[3, 3, 1, 16],
                bias: float32[16],
            },
            Conv_1: {
                kernel: float32[3, 3, 16, 32],
                bias: float32[32],
            },
            Conv_2: {
                kernel: float32[7, 7, 32, 64],
                bias: float32[64],
            },
        },
        Decoder_0: {
            ConvTranspose_0: {
                kernel: float32[7, 7, 64, 32],
                bias: float32[32],
            },
            ConvTranspose_1: {
                kernel: float32[3, 3, 32, 16],
                bias: float32[16],
            },
            ConvTranspose_2: {
                kernel: float32[3, 3, 16, 1],
                bias: float32[1],
            },
        },
    },
})
  2. For the model, the module names & submodules:

PyTorch, for example, displays the model structure quite clearly, so it is easy to see which operations are used:

import torchvision.models.resnet as resnet

model = resnet.resnet18()
print(model)
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
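
In the meantime, a shape-only view can be approximated by hand. A minimal sketch, assuming params is the tree returned by model.init:

import jax

# Collapse every leaf array to a "dtype[shape]" string so that printing
# the tree shows its structure instead of thousands of weight values.
shape_tree = jax.tree_util.tree_map(
    lambda x: f"{x.dtype}{list(x.shape)}", params)
print(shape_tree)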

Conchylicultor avatar Feb 11 '21 15:02 Conchylicultor

I think it makes sense for Module.__repr__ to print the variable structure (name + shape + dtype). This shouldn't be very difficult to add. The same can be done for FrozenDict.__repr__. We could consider a threshold like onp.size(x) > 5 to switch between printing the shape and the full content.
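
A minimal sketch of that threshold rule (the helper names here are hypothetical, not Flax internals):

import jax
import numpy as onp

def _summarize(x, threshold=5):
  # Arrays above the threshold collapse to a "dtype[shape]" string;
  # small arrays are shown with their full contents.
  if onp.size(x) > threshold:
    return f"{x.dtype}{list(x.shape)}"
  return onp.asarray(x)

def pretty_repr(tree):  # hypothetical helper
  return jax.tree_util.tree_map(_summarize, tree)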

jheek avatar Feb 12 '21 10:02 jheek

Haiku has excellent model summary functionality (documentation here). I think equivalents to haiku.experimental.tabulate and haiku.experimental.eval_summary would be very helpful.
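
For reference, a usage sketch of the Haiku API mentioned above (the MLP here is just a stand-in model):

import haiku as hk
import jax.numpy as jnp

def forward(x):
  return hk.nets.MLP([300, 100, 10])(x)

# tabulate wraps the function and returns a callable that renders a
# per-module summary table for the given example input.
print(hk.experimental.tabulate(forward)(jnp.ones((8, 28 * 28))))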

n2cholas avatar Feb 22 '21 17:02 n2cholas

You can use parameter_overview in clu for params visualization:

import jax
import numpy as np
from flax import linen as nn
from clu import parameter_overview

class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

key = jax.random.PRNGKey(0)
variables = CNN().init(key, np.random.randn(1, 32, 32, 3))
print(parameter_overview.get_parameter_overview(variables))
+-----------------------+----------------+-----------+-----------+--------+
| Name                  | Shape          | Size      | Mean      | Std    |
+-----------------------+----------------+-----------+-----------+--------+
| params/Conv_0/bias    | (32,)          | 32        | 0.0       | 0.0    |
| params/Conv_0/kernel  | (3, 3, 3, 32)  | 864       | 0.00277   | 0.2    |
| params/Conv_1/bias    | (64,)          | 64        | 0.0       | 0.0    |
| params/Conv_1/kernel  | (3, 3, 32, 64) | 18,432    | 0.000202  | 0.0591 |
| params/Dense_0/bias   | (256,)         | 256       | 0.0       | 0.0    |
| params/Dense_0/kernel | (4096, 256)    | 1,048,576 | -1.54e-05 | 0.0156 |
| params/Dense_1/bias   | (10,)          | 10        | 0.0       | 0.0    |
| params/Dense_1/kernel | (256, 10)      | 2,560     | -0.00159  | 0.0622 |
+-----------------------+----------------+-----------+-----------+--------+
Total: 1,070,794

myagues avatar Feb 26 '21 15:02 myagues

I think a native equivalent to the Haiku functionality @n2cholas showed would be very nice!

cgarciae avatar Jul 31 '21 14:07 cgarciae

Assigning this to @cgarciae since he is working on this currently.

marcvanzee avatar Apr 27 '22 20:04 marcvanzee

+1

quanvuong avatar Jul 08 '22 22:07 quanvuong

Closing this since @cgarciae has written Module.tabulate; please check it out!
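
A minimal usage sketch, reusing the CNN module from the clu example above (exact columns and formatting depend on the Flax version):

import jax
import jax.numpy as jnp

# Renders a table of modules, input/output shapes, and parameter counts.
print(CNN().tabulate(jax.random.PRNGKey(0), jnp.ones((1, 32, 32, 3))))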

marcvanzee avatar Sep 06 '22 13:09 marcvanzee

Is there a way to visualize the model like http://alexlenail.me/NN-SVG/AlexNet.html?

1kaiser avatar Jan 12 '23 10:01 1kaiser