[CLI]: log gradients broken for subclassed keras models

Open lminer opened this issue 3 years ago • 5 comments

Describe the bug

For TensorFlow Keras, the WandbCallback cannot log gradients for a subclassed model. This is due to the following line: https://github.com/wandb/client/blob/master/wandb/integration/keras/keras.py#L513, which accesses model.inputs, and that attribute is None for a subclassed model. Probably the easiest fix is to accept an extra argument that provides the input dimensions. That way, the callback can create its own tf.keras.Input layer.
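A minimal sketch of the fallback proposed above, assuming TensorFlow 2 with tf.keras. The helper `build_grad_acc_model`, its `input_shape` parameter and the `TinySubclassed` example model are hypothetical names for illustration, not part of wandb's API. The sketch also leaves out the `compile` step the callback performs with its internal optimizer.

```python
import tensorflow as tf


def build_grad_acc_model(model, input_shape=None):
    """Hypothetical helper (not part of wandb).

    Wraps `model` in a functional model. Uses `model.inputs` when it is
    available (functional/Sequential models); otherwise builds a fresh
    tf.keras.Input from `input_shape` (shape without the batch dimension).
    """
    try:
        inputs = model.inputs
    except (AttributeError, ValueError):
        # Guard in case `inputs` is missing (rather than None) in some Keras versions.
        inputs = None

    if inputs is None or (isinstance(inputs, (list, tuple)) and len(inputs) == 0):
        if input_shape is None:
            raise ValueError(
                "model.inputs is unavailable (e.g. a subclassed model); "
                "pass input_shape explicitly"
            )
        inputs = tf.keras.Input(shape=input_shape)

    # Calling the model on symbolic inputs gives symbolic outputs, which
    # tf.keras.Model can wrap into a functional graph.
    outputs = model(inputs)
    return tf.keras.Model(inputs, outputs)


class TinySubclassed(tf.keras.Model):
    """Example subclassed model; per this issue, its `inputs` is not populated."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(2)

    def call(self, x):
        return self.dense(x)


# Usage: a subclassed model needs an explicit input shape.
grad_acc_model = build_grad_acc_model(TinySubclassed(), input_shape=(4,))
```

Falling back only when `model.inputs` is unavailable keeps functional and Sequential models on the existing code path, so the extra argument would be needed only for subclassed models.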

Additional Files

No response

Environment

WandB version: 1.20.0

OS: all

Python version: all

Versions of relevant libraries: all

Additional Context

No response

lminer, Jun 30 '22 13:06

Hi @lminer,

Thanks for bringing this to our notice. I'll let our engineers know about this, and we should have this resolved. Would you happen to have a minimal script to reproduce this issue?

ramit-wandb, Jun 30 '22 21:06

Hi @lminer,

We wanted to follow up with you regarding your support request as we have not heard back from you. Please let us know if we can be of further assistance or if your issue has been resolved.

Best, Weights & Biases

ramit-wandb, Jul 18 '22 17:07

Hi @lminer, since we have not heard back from you, we are going to close this request. If you would like to reopen the conversation, please let us know!

ramit-wandb, Jul 21 '22 20:07

@ramit-wandb I don't have the time to produce reproducible code, but you don't need it in this case. Just look at the code itself:

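```python
# Excerpt from WandbCallback in wandb/integration/keras/keras.py
inputs = self.model.inputs  # None for a subclassed model, so the next lines fail
outputs = self.model(inputs)
grad_acc_model = tf.keras.models.Model(inputs, outputs)
grad_acc_model.compile(loss=self.model.loss, optimizer=_CustomOptimizer())
```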

Subclassed TensorFlow models do not have a populated `inputs` attribute (it is None), so by definition of the API this code will fail.

lminer, Jul 22 '22 10:07