Why can't we accumulate gradients the way PyTorch does? Can we just cumsum the training loss instead?
I found that when I use a very large model like roberta-large, a gradient-accumulation implementation like this one https://gist.github.com/innat/ba6740293e7b7b227829790686f2119c can be very expensive in GPU memory, because I need to store an additional copy of the parameters of the entire RoBERTa model here: "self.gradient_accumulation = [tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False) for v in self.trainable_variables]"
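For reference, the pattern in that gist looks roughly like the following (my paraphrase from memory, not the exact code from the gist, so details may differ): every trainable variable gets a non-trainable accumulator tf.Variable of the same shape, which is where the extra memory goes.

import tensorflow as tf

class GAModel(tf.keras.Model):
    # gradient accumulation with per-variable accumulator slots (paraphrased sketch)
    def __init__(self, n_gradients, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_gradients = tf.constant(n_gradients, dtype=tf.int32)
        self.n_acum_step = tf.Variable(0, dtype=tf.int32, trainable=False)
        # one extra variable per trainable variable -> roughly a full copy of roberta-large
        self.gradient_accumulation = [
            tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False)
            for v in self.trainable_variables]

    def train_step(self, data):
        self.n_acum_step.assign_add(1)
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)
        # add this batch's gradients into the accumulator slots
        for i in range(len(self.gradient_accumulation)):
            self.gradient_accumulation[i].assign_add(gradients[i])
        # every n_gradients steps, apply the accumulated gradients and reset the slots
        tf.cond(tf.equal(self.n_acum_step, self.n_gradients),
                self.apply_accu_gradients, lambda: None)
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def apply_accu_gradients(self):
        self.optimizer.apply_gradients(
            zip(self.gradient_accumulation, self.trainable_variables))
        self.n_acum_step.assign(0)
        for i in range(len(self.gradient_accumulation)):
            self.gradient_accumulation[i].assign(
                tf.zeros_like(self.trainable_variables[i], dtype=tf.float32))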
Below is my implementation, which just imitates the gradient accumulation of PyTorch (a sketch of that PyTorch pattern is at the end of this post). But this implementation is not valid: the loss simply does not decrease. I think it is a problem with the design of tf.GradientTape: once one train step is over, the current tape has not recorded the forward passes of the previous batches, so their contribution to the accumulated loss produces no gradients. How can I resolve this problem?
import tensorflow as tf

class Model2(tf.keras.Model):  # cumsum the loss, not the gradients
    def __init__(self, n_gradients, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_gradients = float(n_gradients)
        self.n_acum_step = 0
        # self.gradient_accumulation = [tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False) for v in self.trainable_variables]
        self.total_loss = 0.00
        # self.g = []

    def train_step(self, data):
        self.n_acum_step += 1
        x, y = data
        # Gradient Tape
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
            self.total_loss += loss  # just cumsum the loss here
        if self.n_acum_step >= self.n_gradients:
            gradients = tape.gradient(self.total_loss / self.n_gradients,
                                      self.trainable_variables)  # remember to average the loss
            self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
            self.total_loss = 0.00  # reset the total loss
            self.n_acum_step = 0  # reset the accumulation step
        # update metrics
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

model = Model2(n_gradients=16, inputs=..., outputs=...)
model.fit(...)
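For comparison, this is the PyTorch pattern I am trying to imitate (a minimal sketch with a hypothetical toy model and data, just to show the accumulation; my real model is roberta-large):

import torch
import torch.nn as nn

# hypothetical toy setup, only to illustrate the accumulation pattern
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
data = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(64)]

accumulation_steps = 16
model.train()
optimizer.zero_grad()
for step, (x, y) in enumerate(data):
    loss = criterion(model(x), y) / accumulation_steps  # average over the window
    loss.backward()              # gradients are summed into param.grad across calls
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()         # one optimizer update per 16 mini-batches
        optimizer.zero_grad()    # clear the accumulated gradients

The key point is that in PyTorch the gradients themselves accumulate in param.grad across backward() calls, without keeping the graphs of the earlier batches alive, which is what I was hoping to reproduce by summing the loss.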