Make `print_summary` default to `True`
Hi, I've been using this module for a month and my experience with it has been largely pleasant. Thank you for your contribution 😄
One thing I observed is that I always pass `print_summary=True`, so I feel `True` is the more user-friendly default. Defaulting to `False` asserts that the user wants the report as a string, when most users just want it printed every time...
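As a sketch of what I mean (this is a stand-in, not the module's actual implementation; `build_report` is a hypothetical helper):

```python
def build_report(model_name):
    # hypothetical helper standing in for the real table construction
    return f"Summary of {model_name}"

def summary(model_name, print_summary=True):
    """Print by default, but still return the string for callers
    who only want the report."""
    report = build_report(model_name)
    if print_summary:
        print(report)
    return report
```

Callers who need only the string can still pass `print_summary=False`; nothing is lost by flipping the default.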
Edit: I added a few more changes around I/O. The design change might be too aggressive, but we can test it out more and try to break it...
By default, we will show all inputs and outputs, including the model's own I/O:
>>> class myNN(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.model1 = nn.Sequential(
...             nn.Linear(100, 200),
...             nn.ReLU(inplace=True),
...             nn.Linear(200, 50),
...             nn.Softmax(-1),
...         )
...         self.model2 = nn.Sequential(
...             nn.Linear(10, 20),
...             nn.ReLU(inplace=True),
...             nn.Linear(20, 5),
...             nn.Softmax(-1),
...         )
...     def forward(self, x1, x2):
...         y1 = self.model1(x1)
...         y2 = self.model2(x2)
...         return y1, y2
>>> mynn = myNN()
>>> summary(mynn, torch.zeros(1,100), torch.zeros(1,10))
-------------------------------------------------------------------------
     Layer (type)            Input Shape          Output Shape      Param #
=========================================================================
            Input      [1, 100], [1, 10]                                -1
         Linear-2               [1, 100]              [1, 200]      20,200
           ReLU-3               [1, 200]              [1, 200]           0
         Linear-4               [1, 200]               [1, 50]      10,050
        Softmax-5                [1, 50]               [1, 50]           0
         Linear-6                [1, 10]               [1, 20]         220
           ReLU-7                [1, 20]               [1, 20]           0
         Linear-8                [1, 20]                [1, 5]         105
        Softmax-9                 [1, 5]                [1, 5]           0
           Output        [1, 50], [1, 5]                                -1
=========================================================================
Total params: 30,577
Trainable params: 30,577
Non-trainable params: 0
-------------------------------------------------------------------------
Hi @sizhky! Thank you so much for your PR! After years of using `summary` in Keras, we needed a similar version in PyTorch :smile:
I agree, printing by default is better. In general, when we call this method we do want to print.
When I created this module, I considered printing both input and output shapes, but thinking of Keras's behavior I realized that, excluding the first and last layers, the information is duplicated, because the output of one layer is (in general) the input to the next. Maybe an option to drop one of them when desired could accommodate programmers who do and who do not want that much information.
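Something like this could handle the row formatting (the `show_input`/`show_output` parameter names are my assumption, not an existing API; real column widths would of course be padded):

```python
def format_row(layer, in_shape, out_shape, params,
               show_input=True, show_output=True):
    # Keep only the requested shape columns; for simple sequential
    # models the duplicated column can be dropped without losing info.
    cols = [layer]
    if show_input:
        cols.append(str(in_shape))
    if show_output:
        cols.append(str(out_shape))
    cols.append(f"{params:,}")
    return "  ".join(cols)
```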
On the other hand, the version with both sounds good, especially in your example. For your example, maybe a version showing the parent layer would be better than printing both input and output shapes. Keras has an option for that, but I haven't implemented it yet.
Did you test your version with the lib examples? I think the version with both shapes will not read well in those examples because of the width of the shape columns. The "problem" with showing both shapes gets clear with them.
What do you think about the points/questions above?
I understand your points about keeping the table succinct. I want to add a couple of points:
- The output shape is only one more column, and it is almost always important to know what goes in and what comes out. Even when it is redundant, it is vital. And in many cases (see the CNN example below) inputs and outputs don't match between consecutive rows because of `F.*` function calls in `forward`.
- Using the parent layer will certainly help, but we must be aware that some modules receive tensors from multiple parents. The least we can do is show all input and output tensor shapes so someone can at least match which tensor goes where.

See the outputs for the lib examples below. The output for the Transformer is very messed up 😄
> The "problem" showing both shapes get clear with them

I didn't understand this part.
CNN
-----------------------------------------------------------------------
    Layer (type)           Input Shape          Output Shape      Param #
=======================================================================
           Input        [1, 1, 28, 28]                                -1
        Conv2d-2        [1, 1, 28, 28]       [1, 10, 24, 24]         260
        Conv2d-3       [1, 10, 12, 12]        [1, 20, 8, 8]        5,020
     Dropout2d-4         [1, 20, 8, 8]        [1, 20, 8, 8]            0
        Linear-5              [1, 320]              [1, 50]       16,050
        Linear-6               [1, 50]              [1, 10]          510
          Output               [1, 10]                                -1
=======================================================================
Total params: 21,838
Trainable params: 21,838
Non-trainable params: 0
-----------------------------------------------------------------------
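For context on the first bullet: the output of Conv2d-2 (`[1, 10, 24, 24]`) does not match the input of Conv2d-3 (`[1, 10, 12, 12]`) because pooling happens via `F.max_pool2d` inside `forward`, which module hooks never see. The shape arithmetic checks out (pure Python, assuming the 5x5 kernels and 2x2 pooling implied by the shapes above):

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    # standard Conv2d output-size formula
    return (size + 2 * padding - kernel) // stride + 1

def maxpool_out(size, window):
    return size // window

h = conv2d_out(28, 5)   # Conv2d-2: 28 -> 24, recorded as [1, 10, 24, 24]
h = maxpool_out(h, 2)   # F.max_pool2d in forward: 24 -> 12 (no hook fires)
h = conv2d_out(h, 5)    # Conv2d-3: 12 -> 8
h = maxpool_out(h, 2)   # second functional pooling: 8 -> 4
flat = 20 * h * h       # 20 channels x 4 x 4 = 320, matching Linear-5's input
```

So without the input-shape column, the functional pooling steps would be completely invisible in the table.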
=========================== Hierarchical Summary ===========================
Net(
  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1)), 260 params
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1)), 5,020 params
  (conv2_drop): Dropout2d(p=0.5, inplace=False), 0 params
  (fc1): Linear(in_features=320, out_features=50, bias=True), 16,050 params
  (fc2): Linear(in_features=50, out_features=10, bias=True), 510 params
), 21,840 params
============================================================================
Transformer
-----------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
===================================================================================
Input [1, 5], [1, 5] -1
Encoder-2 [1, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5] 17,332,224
Decoder-3 [1, 5], [1, 5], [1, 5, 512] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5] 22,060,544
Linear-4 [1, 5, 512] [1, 5, 7] 3,584
Output [5, 7], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5] -1
===================================================================================
Total params: 39,396,350
Trainable params: 39,390,206
Non-trainable params: 6,144
-----------------------------------------------------------------------------------
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
==========================================================================================================================================================================================================================================================================================================================
Input [1, 5], [1, 5] -1
Encoder-2 [1, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5] 17,332,224
Decoder-3 [1, 5], [1, 5], [1, 5, 512] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5] 22,060,544
Linear-4 [1, 5, 512] [1, 5, 7] 3,584
Output [5, 7], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5] -1
==========================================================================================================================================================================================================================================================================================================================
Total params: 39,396,350
Trainable params: 39,390,206
Non-trainable params: 6,144
Batch size: 1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
================================ Hierarchical Summary ================================
Transformer(
(encoder): Encoder(
(src_emb): Embedding(6, 512), 3,072 params
(pos_emb): Embedding(6, 512), 3,072 params
(layers): ModuleList(
(0): EncoderLayer(
(enc_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 2,887,680 params
(1): EncoderLayer(
(enc_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 2,887,680 params
(2): EncoderLayer(
(enc_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 2,887,680 params
(3): EncoderLayer(
(enc_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 2,887,680 params
(4): EncoderLayer(
(enc_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 2,887,680 params
(5): EncoderLayer(
(enc_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 2,887,680 params
), 17,326,080 params
), 17,332,224 params
(decoder): Decoder(
(tgt_emb): Embedding(7, 512), 3,584 params
(pos_emb): Embedding(6, 512), 3,072 params
(layers): ModuleList(
(0): DecoderLayer(
(dec_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(dec_enc_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 3,675,648 params
(1): DecoderLayer(
(dec_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(dec_enc_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 3,675,648 params
(2): DecoderLayer(
(dec_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(dec_enc_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 3,675,648 params
(3): DecoderLayer(
(dec_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(dec_enc_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 3,675,648 params
(4): DecoderLayer(
(dec_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(dec_enc_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 3,675,648 params
(5): DecoderLayer(
(dec_self_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(dec_enc_attn): MultiHeadAttention(
(W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
(W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
), 787,968 params
(pos_ffn): PoswiseFeedForwardNet(
(conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
(conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
), 2,099,712 params
), 3,675,648 params
), 22,053,888 params
), 22,060,544 params
(projection): Linear(in_features=512, out_features=7, bias=False), 3,584 params
), 39,396,352 params
======================================================================================
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
==========================================================================================================================================================================================================================================================================================================================
Input [1, 5], [1, 5] -1
Embedding-2 [1, 5] [1, 5, 512] 3,072
Embedding-3 [1, 5] [1, 5, 512] 3,072
EncoderLayer-4 [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 2,887,680
EncoderLayer-5 [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 2,887,680
EncoderLayer-6 [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 2,887,680
EncoderLayer-7 [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 2,887,680
EncoderLayer-8 [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 2,887,680
EncoderLayer-9 [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 2,887,680
Embedding-10 [1, 5] [1, 5, 512] 3,584
Embedding-11 [1, 5] [1, 5, 512] 3,072
DecoderLayer-12 [1, 5, 512], [1, 5, 512], [1, 5, 5], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5] 3,675,648
DecoderLayer-13 [1, 5, 512], [1, 5, 512], [1, 5, 5], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5] 3,675,648
DecoderLayer-14 [1, 5, 512], [1, 5, 512], [1, 5, 5], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5] 3,675,648
DecoderLayer-15 [1, 5, 512], [1, 5, 512], [1, 5, 5], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5] 3,675,648
DecoderLayer-16 [1, 5, 512], [1, 5, 512], [1, 5, 5], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5] 3,675,648
DecoderLayer-17 [1, 5, 512], [1, 5, 512], [1, 5, 5], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5] 3,675,648
Linear-18 [1, 5, 512] [1, 5, 7] 3,584
Output [5, 7], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5] -1
==========================================================================================================================================================================================================================================================================================================================
Total params: 39,396,350
Trainable params: 39,390,206
Non-trainable params: 6,144
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
==========================================================================================================================================================================================================================================================================================================================
Input [1, 5], [1, 5] -1
Embedding-2 [1, 5] [1, 5, 512] 3,072
Embedding-3 [1, 5] [1, 5, 512] 3,072
Linear-4 [1, 5, 512] [1, 5, 512] 262,656
Linear-5 [1, 5, 512] [1, 5, 512] 262,656
Linear-6 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-7 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-8 [1, 2048, 5] [1, 512, 5] 1,049,088
Linear-9 [1, 5, 512] [1, 5, 512] 262,656
Linear-10 [1, 5, 512] [1, 5, 512] 262,656
Linear-11 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-12 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-13 [1, 2048, 5] [1, 512, 5] 1,049,088
Linear-14 [1, 5, 512] [1, 5, 512] 262,656
Linear-15 [1, 5, 512] [1, 5, 512] 262,656
Linear-16 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-17 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-18 [1, 2048, 5] [1, 512, 5] 1,049,088
Linear-19 [1, 5, 512] [1, 5, 512] 262,656
Linear-20 [1, 5, 512] [1, 5, 512] 262,656
Linear-21 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-22 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-23 [1, 2048, 5] [1, 512, 5] 1,049,088
Linear-24 [1, 5, 512] [1, 5, 512] 262,656
Linear-25 [1, 5, 512] [1, 5, 512] 262,656
Linear-26 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-27 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-28 [1, 2048, 5] [1, 512, 5] 1,049,088
Linear-29 [1, 5, 512] [1, 5, 512] 262,656
Linear-30 [1, 5, 512] [1, 5, 512] 262,656
Linear-31 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-32 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-33 [1, 2048, 5] [1, 512, 5] 1,049,088
Embedding-34 [1, 5] [1, 5, 512] 3,584
Embedding-35 [1, 5] [1, 5, 512] 3,072
Linear-36 [1, 5, 512] [1, 5, 512] 262,656
Linear-37 [1, 5, 512] [1, 5, 512] 262,656
Linear-38 [1, 5, 512] [1, 5, 512] 262,656
Linear-39 [1, 5, 512] [1, 5, 512] 262,656
Linear-40 [1, 5, 512] [1, 5, 512] 262,656
Linear-41 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-42 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-43 [1, 2048, 5] [1, 512, 5] 1,049,088
Linear-44 [1, 5, 512] [1, 5, 512] 262,656
Linear-45 [1, 5, 512] [1, 5, 512] 262,656
Linear-46 [1, 5, 512] [1, 5, 512] 262,656
Linear-47 [1, 5, 512] [1, 5, 512] 262,656
Linear-48 [1, 5, 512] [1, 5, 512] 262,656
Linear-49 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-50 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-51 [1, 2048, 5] [1, 512, 5] 1,049,088
Linear-52 [1, 5, 512] [1, 5, 512] 262,656
Linear-53 [1, 5, 512] [1, 5, 512] 262,656
Linear-54 [1, 5, 512] [1, 5, 512] 262,656
Linear-55 [1, 5, 512] [1, 5, 512] 262,656
Linear-56 [1, 5, 512] [1, 5, 512] 262,656
Linear-57 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-58 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-59 [1, 2048, 5] [1, 512, 5] 1,049,088
Linear-60 [1, 5, 512] [1, 5, 512] 262,656
Linear-61 [1, 5, 512] [1, 5, 512] 262,656
Linear-62 [1, 5, 512] [1, 5, 512] 262,656
Linear-63 [1, 5, 512] [1, 5, 512] 262,656
Linear-64 [1, 5, 512] [1, 5, 512] 262,656
Linear-65 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-66 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-67 [1, 2048, 5] [1, 512, 5] 1,049,088
Linear-68 [1, 5, 512] [1, 5, 512] 262,656
Linear-69 [1, 5, 512] [1, 5, 512] 262,656
Linear-70 [1, 5, 512] [1, 5, 512] 262,656
Linear-71 [1, 5, 512] [1, 5, 512] 262,656
Linear-72 [1, 5, 512] [1, 5, 512] 262,656
Linear-73 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-74 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-75 [1, 2048, 5] [1, 512, 5] 1,049,088
Linear-76 [1, 5, 512] [1, 5, 512] 262,656
Linear-77 [1, 5, 512] [1, 5, 512] 262,656
Linear-78 [1, 5, 512] [1, 5, 512] 262,656
Linear-79 [1, 5, 512] [1, 5, 512] 262,656
Linear-80 [1, 5, 512] [1, 5, 512] 262,656
Linear-81 [1, 5, 512] [1, 5, 512] 262,656
Conv1d-82 [1, 512, 5] [1, 2048, 5] 1,050,624
Conv1d-83 [1, 2048, 5] [1, 512, 5] 1,049,088
Linear-84 [1, 5, 512] [1, 5, 7] 3,584
Output [5, 7], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5] -1
==========================================================================================================================================================================================================================================================================================================================
Total params: 39,396,350
Trainable params: 39,390,206
Non-trainable params: 6,144
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Parent Layers Layer (type) Input Shape Output Shape Param #
======================================================================================================================================================================================================================================================================================================================================================================
Input [1, 5], [1, 5] -1
Transformer/Encoder Embedding-2 [1, 5] [1, 5, 512] 3,072
Transformer/Encoder Embedding-3 [1, 5] [1, 5, 512] 3,072
Transformer/Encoder/EncoderLayer MultiHeadAttention-4 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Encoder/EncoderLayer PoswiseFeedForwardNet-5 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer/Encoder/EncoderLayer MultiHeadAttention-6 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Encoder/EncoderLayer PoswiseFeedForwardNet-7 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer/Encoder/EncoderLayer MultiHeadAttention-8 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Encoder/EncoderLayer PoswiseFeedForwardNet-9 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer/Encoder/EncoderLayer MultiHeadAttention-10 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Encoder/EncoderLayer PoswiseFeedForwardNet-11 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer/Encoder/EncoderLayer MultiHeadAttention-12 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Encoder/EncoderLayer PoswiseFeedForwardNet-13 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer/Encoder/EncoderLayer MultiHeadAttention-14 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Encoder/EncoderLayer PoswiseFeedForwardNet-15 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer/Decoder Embedding-16 [1, 5] [1, 5, 512] 3,584
Transformer/Decoder Embedding-17 [1, 5] [1, 5, 512] 3,072
Transformer/Decoder/DecoderLayer MultiHeadAttention-18 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer MultiHeadAttention-19 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer PoswiseFeedForwardNet-20 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer/Decoder/DecoderLayer MultiHeadAttention-21 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer MultiHeadAttention-22 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer PoswiseFeedForwardNet-23 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer/Decoder/DecoderLayer MultiHeadAttention-24 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer MultiHeadAttention-25 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer PoswiseFeedForwardNet-26 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer/Decoder/DecoderLayer MultiHeadAttention-27 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer MultiHeadAttention-28 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer PoswiseFeedForwardNet-29 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer/Decoder/DecoderLayer MultiHeadAttention-30 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer MultiHeadAttention-31 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer PoswiseFeedForwardNet-32 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer/Decoder/DecoderLayer MultiHeadAttention-33 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer MultiHeadAttention-34 [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5] 787,968
Transformer/Decoder/DecoderLayer PoswiseFeedForwardNet-35 [1, 5, 512] [1, 5, 512] 2,099,712
Transformer Linear-36 [1, 5, 512] [1, 5, 7] 3,584
Output [5, 7], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5] -1
======================================================================================================================================================================================================================================================================================================================================================================
Total params: 39,396,350
Trainable params: 39,390,206
Non-trainable params: 6,144
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
@sizhky is the table width breaking in the Transformer example?
I'm going to merge your PR once the table width is validated, and add an optional parameter to drop one of the columns.