
Tiny numerical differences: weight updates not perfectly matching

Ar-Kareem opened this issue · 2 comments

Hi, thanks for this amazing library.

I noticed one tiny issue: the final weights of the model are different when training with multiple sub-batches per step vs. one big batch per step. I'm not sure whether such numerical differences are expected when using this library.
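To see whether floating-point accumulation order alone can produce this kind of drift, here is a standalone, purely illustrative snippet (nothing to do with GradCache itself) that sums the same float32 values in two different orders:

import torch

torch.manual_seed(0)
x = torch.randn(16, 1024)

# float32 addition is not associative: summing everything in one pass vs.
# summing two halves and adding them usually differs in the last bits.
full = x.sum()
halves = x[:8].sum() + x[8:].sum()
print((full - halves).abs().item())  # small, but typically non-zero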

I'm using CLIP with a contrastive loss. Here's my quick experiment, which I made sure to run multiple times; it produces exactly the same output each run. (Note: I'm using a CLIP model with 151 million parameters and a dataset of only 32 samples for experimental purposes.)

model1 = train_clip_normally(epochs=1, batch_size=16)
model2 = train_clip_gradcache(epochs=1, batch_size=8, batches_per_backward=2)
print(calc_model_param_difference(model1, model2)) # RETURN: 0.3163

Above we see that training with two sub-batches of 8 vs. one batch of 16 gives a tiny difference in the norm between the weights of the two models.

model1 = train_clip_normally(epochs=1, batch_size=16)
model2 = train_clip_gradcache(epochs=1, batch_size=16, batches_per_backward=1)
print(calc_model_param_difference(model1, model2)) # RETURN: 0

Above we see that the models are identical when GradCache performs a backward pass every batch.

model1 = train_clip_gradcache(epochs=1, batch_size=4, batches_per_backward=4)
model2 = train_clip_gradcache(epochs=1, batch_size=8, batches_per_backward=2)
print(calc_model_param_difference(model1, model2)) # RETURN: 0.3105

Above we see that the difference still exists between two different GradCache sub-batch sizes.

However, the library still works great: if I compare it against normal training with the maximum batch size that fits on my GPU, I get a huge difference (which is expected, and exactly why I need this library), as seen below.

model1 = train_clip_normally(epochs=1, batch_size=8)
model2 = train_clip_gradcache(epochs=1, batch_size=8, batches_per_backward=2)
print(calc_model_param_difference(model1, model2)) # RETURN: 363.2708

Below is my code in case the problem is with it:

def train_clip_normally(epochs, batch_size):
    # `d` (the 32-sample dataset), `processor`, and `cliptrain` are defined elsewhere in my script.
    dl = torch.utils.data.DataLoader(d, batch_size=batch_size, shuffle=False)
    model = MyCLIPModel("openai/clip-vit-base-patch32").to('cuda:1')
    optimizer = torch.optim.Adam(model.parameters())
    for e in range(epochs):
        cliptrain.train_epoch(model, optimizer, processor, dl)
    return model

def train_clip_gradcache(epochs, batch_size, batches_per_backward):
    dl = torch.utils.data.DataLoader(d, batch_size=batch_size, shuffle=False)
    model = MyCLIPModel("openai/clip-vit-base-patch32").to('cuda:1')
    optimizer = torch.optim.Adam(model.parameters())
    for e in range(epochs):
        cliptrain.ClipModelClone.grad_cache_train(model, optimizer, processor, dl, batches_per_backward=batches_per_backward)
    return model

def calc_model_param_difference(model1, model2):
    # Sum of the L2 norms of the per-parameter differences between the two models.
    diff = 0
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        diff += torch.norm(p1.data - p2.data)
    return diff

from grad_cache.functional import cached, cat_input_tensor

def grad_cache_train(model, optimizer, processor, dataloader, batches_per_backward):
    cache_x = []
    cache_y = []
    closures_x = []
    closures_y = []

    for step, sub_batch in enumerate(dataloader):  
        inputs = processor(text=sub_batch['text'], return_tensors="pt", padding=True, truncation=True)
        inputs['input_ids'] = inputs['input_ids'].to(model.device)
        inputs['attention_mask'] = inputs['attention_mask'].to(model.device)
        inputs['pixel_values'] = sub_batch['image'].to(model.device)
        inputs['return_loss'] = True

        print('step', step)
        rx, cx = call_text_model(model, inputs)
        ry, cy = call_vision_model(model, inputs)
        
        cache_x.append(rx)
        cache_y.append(ry)
        closures_x.append(cx)
        closures_y.append(cy)
        
        # Once `batches_per_backward` sub-batches are cached, compute the loss over the
        # full concatenated batch; backward() stores gradients on the cached representations.
        if (step + 1) % batches_per_backward == 0:
            print('BACKWARD!')
            loss = grad_cat_loss(cache_x, cache_y, model.logit_scale)
            loss.backward()
            
            # Run the cached closures to propagate the gradients stored on the cached
            # representations back through each sub-batch's forward pass.
            for f, r in zip(closures_x, cache_x):
                f(r)
            for f, r in zip(closures_y, cache_y):
                f(r)

            cache_x = []
            cache_y = []
            closures_x = []
            closures_y = []
        
            optimizer.step()
            optimizer.zero_grad()

# `cat_input_tensor` concatenates the lists of cached sub-batch embeddings,
# so the similarity matrix covers the full effective batch.
@cat_input_tensor
def grad_cat_loss(text_embeds, image_embeds, logit_scale):
    sim = torch.matmul(text_embeds, image_embeds.t()) * logit_scale.exp()
    return clip_loss(sim)

@cached
def call_text_model(model, input):
    return model.forward_text(**input)

@cached
def call_vision_model(model, input):
    return model.forward_visual(**input)

def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
    caption_loss = contrastive_loss(similarity)
    image_loss = contrastive_loss(similarity.t())
    return (caption_loss + image_loss) / 2.0

def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
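
In case it helps put the numbers above in context, here is a small hypothetical helper (not part of the runs above) that reports the relative size of the mismatch instead of the summed norms:

def relative_param_difference(model1, model2):
    # Hypothetical helper: ratio of the norm of the weight difference to the
    # norm of the weights themselves.
    num, den = 0.0, 0.0
    for p1, p2 in zip(model1.parameters(), model2.parameters()):
        num += torch.norm(p1.data - p2.data).item() ** 2
        den += torch.norm(p1.data).item() ** 2
    return (num ** 0.5) / (den ** 0.5)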

Ar-Kareem · Nov 28, 2022