[CLI]: Logging gradients is broken for subclassed Keras models
Describe the bug
For TensorFlow Keras, WandbCallback cannot log gradients for a subclassed model. This is due to the following line: https://github.com/wandb/client/blob/master/wandb/integration/keras/keras.py#L513, which accesses model.inputs, and that attribute is None for a subclassed model. The easiest fix is probably to accept an extra argument that provides the input dimensions, so that the callback can create its own tf.keras.Input layer.
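A minimal sketch of the proposed workaround, assuming a hypothetical `input_shape` argument (per-sample shape, without the batch dimension). This is not an existing WandbCallback parameter, and `build_gradient_model` and `TinyModel` are illustrative names, not wandb code. When the model has no symbolic inputs, the sketch creates a fresh `tf.keras.Input` and calls the model on it, which yields symbolic outputs that a functional wrapper can be built from. The compile step from the original callback code would follow unchanged.

```python
import tensorflow as tf


def build_gradient_model(model, input_shape=None):
    """Wrap `model` in a functional model so its gradients can be computed.

    `input_shape` is a hypothetical extra argument; it is only used when the
    model has no symbolic inputs, as is the case for subclassed models.
    """
    # Functional and Sequential models expose symbolic inputs; subclassed
    # models do not (the attribute is None, or absent in some Keras versions).
    inputs = getattr(model, "inputs", None)
    if inputs is None:
        if input_shape is None:
            raise ValueError(
                "Subclassed models need an explicit input_shape to log gradients."
            )
        # Build a new Input layer from the user-supplied dimensions.
        inputs = tf.keras.Input(shape=input_shape)
    # Calling the model on symbolic inputs produces symbolic outputs,
    # which is all tf.keras.Model needs to build a functional wrapper.
    outputs = model(inputs)
    return tf.keras.Model(inputs, outputs)


class TinyModel(tf.keras.Model):
    """A minimal subclassed model, used only for illustration."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(3)

    def call(self, x):
        return self.dense(x)


grad_model = build_gradient_model(TinyModel(), input_shape=(4,))
print(grad_model(tf.zeros((2, 4))).shape)
```

Raising an error when neither symbolic inputs nor `input_shape` is available keeps the failure explicit, instead of passing None into `tf.keras.models.Model`.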
Additional Files
No response
Environment
WandB version: 1.20.0
OS: all
Python version: all
Versions of relevant libraries: all
Additional Context
No response
Hi @lminer,
Thanks for bringing this to our attention. I'll pass this on to our engineers so we can get it resolved. Do you happen to have a minimal script that reproduces this issue?
Hi @lminer,
We wanted to follow up with you regarding your support request as we have not heard back from you. Please let us know if we can be of further assistance or if your issue has been resolved.
Best, Weights & Biases
Hi @lminer, since we have not heard back from you, we are going to close this request. If you would like to reopen the conversation, please let us know!
@ramit-wandb I don't have the time to produce reproducible code, but you don't need it in this case. Just look at the code itself:
inputs = self.model.inputs
outputs = self.model(inputs)
grad_acc_model = tf.keras.models.Model(inputs, outputs)
grad_acc_model.compile(loss=self.model.loss, optimizer=_CustomOptimizer())
Subclassed TensorFlow models do not have symbolic inputs (their inputs attribute is None), so by definition of the API this code will fail.