burn
Print model structure like with PyTorch
Feature description
I want to see a model's structure at a glance, like when you print a PyTorch model:
import whisper
model = whisper.load_model("tiny")
print(model)
Result:
Whisper(
  (encoder): AudioEncoder(
    (conv1): Conv1d(80, 384, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(384, 384, kernel_size=(3,), stride=(2,), padding=(1,))
    (blocks): ModuleList(
      (0-3): 4 x ResidualAttentionBlock(
        (attn): MultiHeadAttention(
          (query): Linear(in_features=384, out_features=384, bias=True)
          (key): Linear(in_features=384, out_features=384, bias=False)
          (value): Linear(in_features=384, out_features=384, bias=True)
          (out): Linear(in_features=384, out_features=384, bias=True)
        )
        (attn_ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=384, out_features=1536, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=1536, out_features=384, bias=True)
        )
        (mlp_ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      )
    )
    (ln_post): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
  )
  (decoder): TextDecoder(
    (token_embedding): Embedding(51865, 384)
    (blocks): ModuleList(
      (0-3): 4 x ResidualAttentionBlock(
        (attn): MultiHeadAttention(
          (query): Linear(in_features=384, out_features=384, bias=True)
          (key): Linear(in_features=384, out_features=384, bias=False)
          (value): Linear(in_features=384, out_features=384, bias=True)
          (out): Linear(in_features=384, out_features=384, bias=True)
        )
        (attn_ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (cross_attn): MultiHeadAttention(
          (query): Linear(in_features=384, out_features=384, bias=True)
          (key): Linear(in_features=384, out_features=384, bias=False)
          (value): Linear(in_features=384, out_features=384, bias=True)
          (out): Linear(in_features=384, out_features=384, bias=True)
        )
        (cross_attn_ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=384, out_features=1536, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=1536, out_features=384, bias=True)
        )
        (mlp_ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      )
    )
    (ln): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
  )
)
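The output above relies on two conventions: each child is printed as (name): TypeName(...) one indentation level deeper than its parent, and runs of identical consecutive children are collapsed into a single line such as (0-3): 4 x ResidualAttentionBlock(. A minimal Rust sketch of the collapsing step follows, assuming each child has already been rendered to text. Child and collapse_repeats are hypothetical names, not part of burn, and comparing rendered text is only assumed to approximate how PyTorch groups ModuleList entries.

```rust
/// Hypothetical representation of one child line in the printed tree:
/// its field name (e.g. "0") and its already-rendered text.
#[derive(Debug, Clone, PartialEq)]
pub struct Child {
    pub name: String,
    pub rendered: String,
}

/// Collapses runs of consecutive children with identical rendered text into
/// one entry named "first-last" whose text is prefixed with "N x ".
/// Children that are not repeated are kept unchanged.
pub fn collapse_repeats(children: &[Child]) -> Vec<Child> {
    let mut out = Vec::new();
    let mut i = 0;
    while i < children.len() {
        // Advance j to the end (exclusive) of the run that starts at i.
        let mut j = i + 1;
        while j < children.len() && children[j].rendered == children[i].rendered {
            j += 1;
        }
        let run = j - i;
        if run == 1 {
            out.push(children[i].clone());
        } else {
            out.push(Child {
                name: format!("{}-{}", children[i].name, children[j - 1].name),
                rendered: format!("{} x {}", run, children[i].rendered),
            });
        }
        i = j;
    }
    out
}
```

For the Whisper encoder, four identical blocks named "0" through "3" would collapse into a single entry named "0-3" with text "4 x ResidualAttentionBlock(...)".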
Feature motivation
This would let users see a model's structure at a glance instead of reading through its code.
I did some digging around. This one looks like a pretty easy starting point, so I'll give it a shot.
Would this be preferred as the implementation of Display::fmt? Right now, Display only writes the module name and its number of parameters, which isn't hugely useful.
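If the tree is rendered through Display::fmt, the core is a recursive writer that tracks indentation depth. A minimal sketch, assuming a hypothetical ModuleNode struct that already holds each module's type name, a one-line config summary, and its named children (collecting that data from real burn modules is a separate problem and is not shown):

```rust
use std::fmt;

/// Hypothetical description of one module, not burn's API.
pub struct ModuleNode {
    pub type_name: String,
    /// One-line config summary, printed only for leaf modules in this sketch.
    pub extra: String,
    /// Child modules keyed by field name, in declaration order.
    pub children: Vec<(String, ModuleNode)>,
}

impl ModuleNode {
    /// Writes this node at `depth`, using two spaces per indentation level.
    fn fmt_at(&self, f: &mut fmt::Formatter<'_>, depth: usize) -> fmt::Result {
        if self.children.is_empty() {
            // Leaf modules fit on one line, e.g. `Linear(in_features=384, ...)`.
            return write!(f, "{}({})", self.type_name, self.extra);
        }
        writeln!(f, "{}(", self.type_name)?;
        let pad = "  ".repeat(depth + 1);
        for (name, child) in &self.children {
            write!(f, "{}({}): ", pad, name)?;
            child.fmt_at(f, depth + 1)?;
            writeln!(f)?;
        }
        // The closing parenthesis aligns with the line that opened this node.
        write!(f, "{})", "  ".repeat(depth))
    }
}

impl fmt::Display for ModuleNode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.fmt_at(f, 0)
    }
}
```

With this approach, println!("{}", model) prints the whole tree, as in PyTorch. If the current one-line summary is still wanted, it could be kept behind the alternate flag ({:#}, checked via f.alternate()) or a separate method.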
Alternatively, a new method on the Module trait, e.g. tree(), would probably work just fine.
I'd prefer that we implement Display::fmt. By the way, you may have to look into the burn-derive crate to get the attribute names and tree structure.
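The missing ingredient in both approaches is each field's name, which a derive macro can see at compile time. The sketch below hand-writes the kind of impl such a macro could emit for a small MLP, and adds a default tree() method in the spirit of the alternative suggested above. Every name here (TreeNode, named_children, extra, tree, write_tree, Linear, Gelu, Mlp and their fields) is hypothetical and not burn's actual API.

```rust
/// Hypothetical trait: a module reports its type name, an optional one-line
/// config summary, and its child modules keyed by field name.
pub trait TreeNode {
    fn type_name(&self) -> &'static str;
    fn extra(&self) -> String {
        String::new()
    }
    fn named_children(&self) -> Vec<(&'static str, &dyn TreeNode)> {
        Vec::new()
    }

    /// Default method: renders this module and its children as an indented tree.
    fn tree(&self) -> String {
        let mut out = String::new();
        self.write_tree(&mut out, 0);
        out
    }

    /// Recursive helper used by `tree`, two spaces per indentation level.
    fn write_tree(&self, out: &mut String, depth: usize) {
        let children = self.named_children();
        if children.is_empty() {
            out.push_str(&format!("{}({})", self.type_name(), self.extra()));
            return;
        }
        out.push_str(self.type_name());
        out.push_str("(\n");
        for (name, child) in children {
            out.push_str(&"  ".repeat(depth + 1));
            out.push_str(&format!("({}): ", name));
            child.write_tree(out, depth + 1);
            out.push('\n');
        }
        out.push_str(&"  ".repeat(depth));
        out.push(')');
    }
}

pub struct Linear { pub d_in: usize, pub d_out: usize }
impl TreeNode for Linear {
    fn type_name(&self) -> &'static str { "Linear" }
    fn extra(&self) -> String { format!("d_input={}, d_output={}", self.d_in, self.d_out) }
}

pub struct Gelu;
impl TreeNode for Gelu {
    fn type_name(&self) -> &'static str { "Gelu" }
}

pub struct Mlp { pub linear1: Linear, pub gelu: Gelu, pub linear2: Linear }

// A derive macro could generate this impl: the field names become string literals.
impl TreeNode for Mlp {
    fn type_name(&self) -> &'static str { "Mlp" }
    fn named_children(&self) -> Vec<(&'static str, &dyn TreeNode)> {
        vec![
            ("linear1", &self.linear1 as &dyn TreeNode),
            ("gelu", &self.gelu as &dyn TreeNode),
            ("linear2", &self.linear2 as &dyn TreeNode),
        ]
    }
}
```

Whether the entry point ends up as tree() or Display::fmt, the per-module data (names and children) would come from the same derived code; Display could simply delegate to the same recursive writer.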
I have started working on this issue and I have a good design solution that's flexible and robust.