
Custom Optimizers not working [BUG]

Open abhiTokopedia opened this issue 1 year ago • 2 comments

To reproduce, run the following:

import tensorflow as tf

class MyMomentumOptimizer(tf.keras.optimizers.Optimizer):
    def __init__(self, learning_rate=0.001, momentum=0.9, name="MyMomentumOptimizer", **kwargs):
        """Call super().__init__() and use _set_hyper() to store hyperparameters"""
        super().__init__(name, **kwargs)  # pass the caller's name instead of a hard-coded string
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate)) # handle lr=learning_rate
        self._set_hyper("decay", self._initial_decay)
        self._set_hyper("momentum", momentum)
    
    def _create_slots(self, var_list):
        """For each model variable, create the optimizer variable associated with it.
        TensorFlow calls these optimizer variables "slots".
        For momentum optimization, we need one momentum slot per model variable.
        """
        for var in var_list:
            self.add_slot(var, "momentum")

    @tf.function
    def _resource_apply_dense(self, grad, var):
        """Update the slots and perform one optimization step for one model variable
        """
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype) # handle learning rate decay
        momentum_var = self.get_slot(var, "momentum")
        momentum_hyper = self._get_hyper("momentum", var_dtype)
        momentum_var.assign(momentum_var * momentum_hyper - (1. - momentum_hyper) * grad)
        var.assign_add(momentum_var * lr_t)

    def _resource_apply_sparse(self, grad, var):
        raise NotImplementedError

    def get_config(self):
        base_config = super().get_config()
        return {
            **base_config,
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
            "decay": self._serialize_hyperparameter("decay"),
            "momentum": self._serialize_hyperparameter("momentum"),
        }

The following exception occurred:

AttributeError: 'MyMomentumOptimizer' object has no attribute '_set_hyper'

Versions:

  • OS: [Debian 4.19.269-1]
  • Python: [3.7]
  • TensorFlow: [2.11]

Additional context: After looking at the base class, I couldn't find _set_hyper defined there.

abhiTokopedia, Mar 22 '23 13:03

[Update] - This code is not intended to work with TensorFlow 2.11.

abhiTokopedia, Mar 22 '23 13:03

Any further updates? Is there no way to make it work with TF 2.11?

lxlofpku, Jul 02 '23 11:07
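
A note for anyone landing here: as of TensorFlow 2.11, tf.keras.optimizers.Optimizer refers to the new optimizer implementation, and the older base class that defines _set_hyper was moved to the tf.keras.optimizers.legacy namespace, so subclassing tf.keras.optimizers.legacy.Optimizer is one thing to try. Whichever base class ends up being used, it may help to check the update rule itself without TensorFlow. The sketch below restates the formula from _resource_apply_dense on plain Python floats; the function name momentum_step and the toy quadratic are assumptions made for illustration and are not part of the original code.

```python
def momentum_step(var, grad, momentum_var, learning_rate=0.001, momentum=0.9):
    """One step of the update from _resource_apply_dense, on plain floats.

    momentum_step is a hypothetical helper name, not a TensorFlow API.
    """
    # Same formula as in the issue: a running average of negative gradients.
    momentum_var = momentum_var * momentum - (1.0 - momentum) * grad
    # Move the variable along the averaged direction, scaled by the learning rate.
    var = var + momentum_var * learning_rate
    return var, momentum_var


# Hypothetical toy problem: minimize f(x) = x**2, whose gradient is 2 * x.
x, m = 5.0, 0.0
for _ in range(500):
    x, m = momentum_step(x, 2.0 * x, m, learning_rate=0.1)

print(x)  # should be very close to 0 if the update rule behaves as intended
```

A port to a different optimizer base class can then be compared against this reference on a few hand-picked values before training a real model.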