Can we accumulate gradients like torch does, by just cumsum-ing the training loss?

wangbingnan136 opened this issue 3 years ago • 0 comments

I found that when I use a very large model like roberta-large, the gradient-accumulation implementation in https://gist.github.com/innat/ba6740293e7b7b227829790686f2119c can be very expensive in GPU memory, because it has to store an additional copy of all the parameters of the entire RoBERTa model here: " self.gradient_accumulation = [tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False) for v in self.trainable_variables] "
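
For context, the gist's approach is roughly the following sketch (reconstructed from memory, so details such as the class name and the cond-based dispatch may differ): it keeps one persistent accumulator tf.Variable per trainable variable, adds each batch's gradients into those accumulators, and only calls apply_gradients every n_gradients steps. Those accumulator variables are exactly the extra full copy of the model parameters that blows up memory for roberta-large.

import tensorflow as tf

class GAModel(tf.keras.Model):  # sketch: accumulate gradients in extra variables
    def __init__(self, n_gradients, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_gradients = tf.constant(n_gradients, dtype=tf.int32)
        self.n_acum_step = tf.Variable(0, dtype=tf.int32, trainable=False)
        # one zero accumulator per trainable variable -> an extra full copy of the parameters
        self.gradient_accumulation = [
            tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False)
            for v in self.trainable_variables
        ]

    def train_step(self, data):
        self.n_acum_step.assign_add(1)
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        # add this batch's gradients into the persistent accumulators
        gradients = tape.gradient(loss, self.trainable_variables)
        for acc, grad in zip(self.gradient_accumulation, gradients):
            acc.assign_add(grad)
        # every n_gradients steps, apply the accumulated gradients and reset
        tf.cond(tf.equal(self.n_acum_step, self.n_gradients),
                self.apply_accu_gradients, lambda: None)
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def apply_accu_gradients(self):
        self.optimizer.apply_gradients(
            zip(self.gradient_accumulation, self.trainable_variables))
        self.n_acum_step.assign(0)
        for acc in self.gradient_accumulation:
            acc.assign(tf.zeros_like(acc))

The tf.Variable counter and the tf.cond dispatch are there so that the step keeps working when Keras compiles train_step into a tf.function graph.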

Below is my implementation, just an imitation of PyTorch's gradient accumulation, but it is not valid: the loss simply did not decrease. I think the problem is in the design of tf.GradientTape: once one batch is over, the tape no longer holds the operations recorded for the previous batches, so the accumulated loss from earlier batches contributes no gradient (a minimal standalone sketch of this behavior follows the code below). How can I resolve this problem?

class Model2(tf.keras.Model):  # cumsum the loss, not the gradients
    def __init__(self, n_gradients, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_gradients = float(n_gradients)
        self.n_acum_step = 0
        # self.gradient_accumulation = [tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False) for v in self.trainable_variables]
        self.total_loss = 0.0
        # self.g = []

    def train_step(self, data):
        self.n_acum_step += 1

        x, y = data
        # Gradient tape
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
            self.total_loss += loss  # just cumsum the loss here
        if self.n_acum_step >= self.n_gradients:
            # remember to average the accumulated loss before differentiating
            gradients = tape.gradient(self.total_loss / self.n_gradients, self.trainable_variables)
            self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
            self.total_loss = 0.0  # reset the accumulated loss
            self.n_acum_step = 0  # reset the accumulation step

        # update metrics
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

model = Model2(n_gradients=16,inputs...,outputs...)
model.fit....
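
To illustrate why I think the carried-over total_loss contributes nothing, here is a minimal standalone sketch (just my understanding of GradientTape, independent of the model above): a loss computed under a previous, already-closed tape is treated as a constant by the new tape, so only the current batch's loss produces gradients.

import tensorflow as tf

w = tf.Variable(2.0)

# "batch 1": computed under a tape that has already been exited
with tf.GradientTape() as tape1:
    loss1 = w * 3.0

# "batch 2": a new tape only records the ops executed inside its own context
with tf.GradientTape() as tape2:
    loss2 = w * 5.0
    total = loss1 + loss2  # loss1 is just a constant tensor as far as tape2 is concerned

print(tape2.gradient(total, w))  # prints 5.0, not 8.0 -> only loss2 contributes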

wangbingnan136 · Apr 03 '22 23:04