GradCache
Tiny numerical differences: weight updates not perfectly matching
Hi, thanks for this amazing library.
I saw one tiny issue: the final weights of the model are different when training with multiple sub_batches per step vs. one big batch per step. I'm not sure whether such numerical differences are expected when using this library.
I'm using CLIP with a contrastive loss. Here's my quick experimental code, which I made sure to run multiple times; it produces exactly the same output each time. (Note: I'm using CLIP with 151 million parameters and a dataset of only 32 samples for experimental purposes.)
model1 = train_clip_normally(epochs=1, batch_size=16)
model2 = train_clip_gradcache(epochs=1, batch_size=8, batches_per_backward=2)
print(calc_model_param_difference(model1, model2)) # RETURN: 0.3163
Above we see that training with two sub_batches of 8 vs. one batch of 16 gives a tiny difference in the norm of the weights of the two models.
model1 = train_clip_normally(epochs=1, batch_size=16)
model2 = train_clip_gradcache(epochs=1, batch_size=16, batches_per_backward=1)
print(calc_model_param_difference(model1, model2)) # RETURN: 0
Above we see that the models are equivalent when GradCache performs a backward on every batch.
model1 = train_clip_gradcache(epochs=1, batch_size=4, batches_per_backward=4)
model2 = train_clip_gradcache(epochs=1, batch_size=8, batches_per_backward=2)
print(calc_model_param_difference(model1, model2)) # RETURN: 0.3105
Above we see that the difference still exists between two different GradCache sub-batch sizes.
However, this library is still working amazingly: if I compare it against training normally with whatever maximum batch size fits on my GPU, I get a huge difference (which is expected, and exactly why I need this library), as seen below.
model1 = train_clip_normally(epochs=1, batch_size=8)
model2 = train_clip_gradcache(epochs=1, batch_size=8, batches_per_backward=2)
print(calc_model_param_difference(model1, model2)) # RETURN: 363.2708
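For perspective, it might also be worth reporting the difference relative to the overall weight norm, since an absolute value like 0.3163 summed over all tensors of a 151M-parameter model is hard to interpret on its own. A minimal sketch of such a check (the helper name is my own, not part of GradCache):

import torch

def calc_relative_param_difference(model1, model2):
    # Norm of the concatenated parameter difference, divided by the norm of
    # model1's parameters, i.e. the deviation in relative terms.
    diff_sq, ref_sq = 0.0, 0.0
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        diff_sq += torch.sum((p1.data - p2.data) ** 2).item()
        ref_sq += torch.sum(p1.data ** 2).item()
    return (diff_sq ** 0.5) / (ref_sq ** 0.5)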
Below is my code in case the problem is with it:
def train_clip_normally(epochs, batch_size):
    dl = torch.utils.data.DataLoader(d, batch_size=batch_size, shuffle=False)
    model = MyCLIPModel("openai/clip-vit-base-patch32").to('cuda:1')
    optimizer = torch.optim.Adam(model.parameters())
    for e in range(epochs):
        cliptrain.train_epoch(model, optimizer, processor, dl)
    return model

def train_clip_gradcache(epochs, batch_size, batches_per_backward):
    dl = torch.utils.data.DataLoader(d, batch_size=batch_size, shuffle=False)
    model = MyCLIPModel("openai/clip-vit-base-patch32").to('cuda:1')
    optimizer = torch.optim.Adam(model.parameters())
    for e in range(epochs):
        cliptrain.ClipModelClone.grad_cache_train(model, optimizer, processor, dl, batches_per_backward=batches_per_backward)
    return model

def calc_model_param_difference(model1, model2):
    diff = 0
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        diff += torch.norm(p1.data - p2.data)
    return diff
from grad_cache.functional import cached, cat_input_tensor
def grad_cache_train(model, optimizer, processor, dataloader, batches_per_backward):
    # Representations and gradient-cache closures accumulated across sub-batches.
    cache_x = []
    cache_y = []
    closures_x = []
    closures_y = []
    for step, sub_batch in enumerate(dataloader):
        inputs = processor(text=sub_batch['text'], return_tensors="pt", padding=True, truncation=True)
        inputs['input_ids'] = inputs['input_ids'].to(model.device)
        inputs['attention_mask'] = inputs['attention_mask'].to(model.device)
        inputs['pixel_values'] = sub_batch['image'].to(model.device)
        inputs['return_loss'] = True
        print('step', step)
        rx, cx = call_text_model(model, inputs)
        ry, cy = call_vision_model(model, inputs)
        cache_x.append(rx)
        cache_y.append(ry)
        closures_x.append(cx)
        closures_y.append(cy)
        if (step + 1) % batches_per_backward == 0:
            print('BACKWARD!')
            # Contrastive loss over all cached sub-batch representations, then
            # run the cached closures to backprop through each encoder pass.
            loss = grad_cat_loss(cache_x, cache_y, model.logit_scale)
            loss.backward()
            for f, r in zip(closures_x, cache_x):
                f(r)
            for f, r in zip(closures_y, cache_y):
                f(r)
            cache_x = []
            cache_y = []
            closures_x = []
            closures_y = []
            optimizer.step()
            optimizer.zero_grad()
@cat_input_tensor
def grad_cat_loss(text_embeds, image_embeds, logit_scale):
    sim = torch.matmul(text_embeds, image_embeds.t()) * logit_scale.exp()
    return clip_loss(sim)

@cached
def call_text_model(model, input):
    return model.forward_text(**input)

@cached
def call_vision_model(model, input):
    return model.forward_visual(**input)

def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
    caption_loss = contrastive_loss(similarity)
    image_loss = contrastive_loss(similarity.t())
    return (caption_loss + image_loss) / 2.0

def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
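In case it's relevant context: differences of this magnitude can arise purely from floating-point summation order. Splitting a batch into sub_batches changes the tensor shapes the kernels see, and therefore the order in which reductions are accumulated, so the results are mathematically equal but not necessarily bitwise equal. A tiny standalone check (nothing to do with GradCache itself) that illustrates the effect might look like this:

import torch

torch.manual_seed(0)
x = torch.randn(16, 512)

# The same reduction performed in two different orders: one pass over the
# full tensor vs. two halves summed separately and then added together.
full = x.sum()
chunked = x[:8].sum() + x[8:].sum()

print(torch.equal(full, chunked))       # may be False in float32
print((full - chunked).abs().item())    # if nonzero, a tiny rounding-level value

Fixing seeds and enabling torch.use_deterministic_algorithms(True) should make each individual run reproducible, but as far as I understand it will not make the two batching schemes agree bitwise.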