softlearning

Possibility to show structure of the model

Open • kapsl opened this issue 5 years ago • 1 comment

Hi, is there an easy way to show the structure of the Keras model, the way you normally would by calling .summary() after compiling it?

kapsl, Apr 06 '20 08:04

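Good question! Here's how you can currently get the summaries of the underlying Keras models, by calling .summary() on the model attributes directly (shown here from an ipdb session):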

ipdb> policy.shift_and_scale_model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
observations (InputLayer)    [(None, 3)]               0
_________________________________________________________________
lambda_4 (Lambda)            (None, 3)                 0
_________________________________________________________________
feedforward_model (Sequentia (None, 2)                 67330
_________________________________________________________________
lambda_6 (Lambda)            [(None, 1), (None, 1)]    0
_________________________________________________________________
lambda_7 (Lambda)            (None, 1)                 0
=================================================================
Total params: 67,330
Trainable params: 67,330
Non-trainable params: 0
_________________________________________________________________
ipdb> Qs[0].model.summary()
Model: "feedforward_Q"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
0/observations (InputLayer)     [(None, 3)]          0
__________________________________________________________________________________________________
1 (InputLayer)                  [(None, 1)]          0
__________________________________________________________________________________________________
lambda (Lambda)                 (None, 4)            0           0/observations[0][0]
                                                                 1[0][0]
__________________________________________________________________________________________________
feedforward_Q (Sequential)      (None, 1)            67329       lambda[0][0]
==================================================================================================
Total params: 67,329
Trainable params: 67,329
Non-trainable params: 0
__________________________________________________________________________________________________

It's a bit clumsy, and ideally we'd have a .summary() method implemented in the policy and value function classes. I'll try to implement those at some point. Let's keep this issue open for now as a feature request.
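A minimal sketch of what such a .summary() method could look like, assuming the policy and Q-function objects keep their Keras models in the attributes seen in the session above (`shift_and_scale_model` on the policy, `model` on each Q-function). `KerasSummaryMixin`, `_summary_model_attributes`, and `summary_string` are hypothetical names, not part of softlearning; the only real API used is the optional `print_fn` argument of Keras' `Model.summary()`.

```python
from typing import Callable, Optional, Sequence


class KerasSummaryMixin:
    """Hypothetical mixin: gives a policy or value-function class a
    Keras-style .summary() by delegating to the Keras model(s) it wraps.

    Subclasses list the attribute names that hold keras.Model instances.
    Those attributes are internal (e.g. "shift_and_scale_model", "model"),
    not a documented softlearning API.
    """

    # Hypothetical class attribute: names of attributes holding Keras models.
    _summary_model_attributes: Sequence[str] = ()

    def summary(self, print_fn: Optional[Callable[[str], None]] = None) -> None:
        """Call .summary() on every wrapped model, in the listed order."""
        for name in self._summary_model_attributes:
            model = getattr(self, name)
            # keras.Model.summary accepts an optional print_fn; None keeps
            # the Keras default of printing to stdout.
            model.summary(print_fn=print_fn)

    def summary_string(self) -> str:
        """Collect the summaries into one string instead of printing them."""
        lines = []

        def _collect(line, *_args, **_kwargs):
            # Extra arguments are ignored in case a Keras version passes any.
            lines.append(line)

        self.summary(print_fn=_collect)
        return "\n".join(lines)
```

With a hypothetical subclass such as `class MyPolicy(KerasSummaryMixin, <existing policy class>)` that sets `_summary_model_attributes = ("shift_and_scale_model",)`, the ipdb calls above would reduce to `policy.summary()`, and likewise `Qs[0].summary()` for a Q-function class listing `("model",)`. Delegating keeps the output identical to Keras' own summary rather than re-implementing the table.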

hartikainen, Apr 06 '20 12:04