
Print model structure like with PyTorch

[Open] antimora opened this issue 1 year ago • 3 comments

Feature description

Want to see a model's structure at a glance, like when you print a PyTorch model:

import whisper

# Printing any torch.nn.Module recursively lists its named submodules
model = whisper.load_model("tiny")
print(model)

Result:

Whisper(
  (encoder): AudioEncoder(
    (conv1): Conv1d(80, 384, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(384, 384, kernel_size=(3,), stride=(2,), padding=(1,))
    (blocks): ModuleList(
      (0-3): 4 x ResidualAttentionBlock(
        (attn): MultiHeadAttention(
          (query): Linear(in_features=384, out_features=384, bias=True)
          (key): Linear(in_features=384, out_features=384, bias=False)
          (value): Linear(in_features=384, out_features=384, bias=True)
          (out): Linear(in_features=384, out_features=384, bias=True)
        )
        (attn_ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=384, out_features=1536, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=1536, out_features=384, bias=True)
        )
        (mlp_ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      )
    )
    (ln_post): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
  )
  (decoder): TextDecoder(
    (token_embedding): Embedding(51865, 384)
    (blocks): ModuleList(
      (0-3): 4 x ResidualAttentionBlock(
        (attn): MultiHeadAttention(
          (query): Linear(in_features=384, out_features=384, bias=True)
          (key): Linear(in_features=384, out_features=384, bias=False)
          (value): Linear(in_features=384, out_features=384, bias=True)
          (out): Linear(in_features=384, out_features=384, bias=True)
        )
        (attn_ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (cross_attn): MultiHeadAttention(
          (query): Linear(in_features=384, out_features=384, bias=True)
          (key): Linear(in_features=384, out_features=384, bias=False)
          (value): Linear(in_features=384, out_features=384, bias=True)
          (out): Linear(in_features=384, out_features=384, bias=True)
        )
        (cross_attn_ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=384, out_features=1536, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=1536, out_features=384, bias=True)
        )
        (mlp_ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      )
    )
    (ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
  )
)

Feature motivation

Lets you see a model's structure at a glance instead of reading through the code.

antimora · Feb 23 '24 19:02

Did some digging around. This one looks like a pretty easy starting point, so I'll give it a shot.

Is this something that would be preferred as the implementation of Display::fmt? Right now Display only writes the name and number of parameters, which isn't hugely useful. Alternatively, a new function on the Module trait, e.g. tree(), would work just fine.
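
For concreteness, here's a minimal hand-written sketch of the kind of Display impl I mean. Linear, Mlp, and their fields are made-up stand-ins for illustration, not Burn's actual module types:

use std::fmt;

struct Linear {
    d_input: usize,
    d_output: usize,
    bias: bool,
}

impl fmt::Display for Linear {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Linear(d_input={}, d_output={}, bias={})",
            self.d_input, self.d_output, self.bias
        )
    }
}

struct Mlp {
    fc1: Linear,
    fc2: Linear,
}

impl fmt::Display for Mlp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Indent each child by two spaces, mirroring PyTorch's layout.
        writeln!(f, "Mlp(")?;
        writeln!(f, "  (fc1): {}", self.fc1)?;
        writeln!(f, "  (fc2): {}", self.fc2)?;
        write!(f, ")")
    }
}

fn main() {
    let model = Mlp {
        fc1: Linear { d_input: 384, d_output: 1536, bias: true },
        fc2: Linear { d_input: 1536, d_output: 384, bias: true },
    };
    println!("{model}");
}

One wrinkle: a naive impl like this doesn't re-indent children of children, so nested modules would need an indentation depth threaded through somehow.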

McArthur-Alford · Mar 15 '24 15:03

I'd prefer that we implement Display::fmt. By the way, you may have to look into the burn-derive crate to get the attribute names and tree structure.
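
To illustrate the direction: the derive macro already sees every field name and type of a module struct, so it could generate an implementation of a tree-printing helper that carries an indentation depth. Below is a hypothetical hand-expansion of what that generated code might look like; FmtTree, Linear, and Block are illustrative names, not an existing Burn API:

use std::fmt;

// Hypothetical tree-printing trait; #[derive(Module)] in burn-derive
// could generate impls of something like this for each module struct.
trait FmtTree {
    fn fmt_tree(&self, f: &mut fmt::Formatter<'_>, depth: usize) -> fmt::Result;
}

// Leaf module: prints on a single line.
struct Linear { d_input: usize, d_output: usize }

impl FmtTree for Linear {
    fn fmt_tree(&self, f: &mut fmt::Formatter<'_>, _depth: usize) -> fmt::Result {
        write!(f, "Linear(d_input={}, d_output={})", self.d_input, self.d_output)
    }
}

// Container module: the derive would emit one line per field,
// indented by depth, recursing into each child.
struct Block { attn: Linear, mlp: Linear }

impl FmtTree for Block {
    fn fmt_tree(&self, f: &mut fmt::Formatter<'_>, depth: usize) -> fmt::Result {
        let pad = "  ".repeat(depth + 1);
        writeln!(f, "Block(")?;
        write!(f, "{pad}(attn): ")?;
        self.attn.fmt_tree(f, depth + 1)?;
        writeln!(f)?;
        write!(f, "{pad}(mlp): ")?;
        self.mlp.fmt_tree(f, depth + 1)?;
        writeln!(f)?;
        write!(f, "{})", "  ".repeat(depth))
    }
}

// Display then just forwards to the tree printer at depth 0.
impl fmt::Display for Block {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.fmt_tree(f, 0)
    }
}

fn main() {
    let block = Block {
        attn: Linear { d_input: 384, d_output: 384 },
        mlp: Linear { d_input: 384, d_output: 1536 },
    };
    println!("{block}");
}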

antimora · Mar 15 '24 20:03

I have started working on this issue and have a design solution that's flexible and robust.

antimora · May 13 '24 04:05