powernorm

A few questions regarding fairseq/modules/norms/mask_powernorm.py

Open congwang093 opened this issue 1 year ago • 18 comments

Hi, first of all thank you for your work. I've been spending some time trying to understand what is happening in the script fairseq/modules/norms/mask_powernorm.py, but I've been having some trouble. Can you please answer these questions?

  1. Was 'GroupScaling1D' (starting at line 17) specific to the data or model architecture used for the experiments, rather than part of the general PowerNorm method? Based on my understanding, the input is shaped (T tokens or instances, B batch, C channels). It seems to be a modified layer norm in which each value of the input tensor is divided by the root mean square of the values in its channel group (the channels are split into group_num groups, 4 by default), separately for each batch element and each token. I believe this was not mentioned in the paper.

  2. On these lines in the forward function of PowerFunction:

if current_iter < warmup_iters:
    running_phi.copy_(running_phi * (current_iter-1)/current_iter + var.mean(dim=0, keepdim=True)/current_iter)
running_phi.copy_(afwd*running_phi + (1-afwd)*var.mean(dim=0, keepdim=True))

Since 'var' has shape (1, C, 1, 1), var.mean(dim=0, keepdim=True) is the same tensor as 'var'. Was this intentional, or perhaps an artifact from an earlier version of the code? Also, did you mean to put an else before 'running_phi.copy_(afwd*running_phi + (1-afwd)*var.mean(dim=0, keepdim=True))'?
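Regarding question 1, here is a minimal sketch of what `GroupScaling1D` computes, written for this thread. The helper name `group_scaling_1d` is hypothetical and does not come from the repo. The C channels are split into `group_num` groups of `C // group_num` channels each, and every value is divided by the square root of its group's mean of squares (plus `eps`). With `group_num=1` this reduces to an RMSNorm-style per-token rescaling without a learned gain.

```python
import torch


def group_scaling_1d(x: torch.Tensor, group_num: int = 4, eps: float = 1e-5) -> torch.Tensor:
    """Sketch of GroupScaling1D for a (T, B, C) tensor (hypothetical helper name).

    C channels are split into `group_num` groups of C // group_num channels, and each
    value is divided by the root mean square of its own group.
    """
    T, B, C = x.shape
    assert C % group_num == 0, "C must be divisible by group_num"
    groups = x.reshape(T, B, group_num, C // group_num)
    # Mean of squares over the channels of each group -> shape (T, B, group_num, 1)
    moment2 = groups.pow(2).mean(dim=-1, keepdim=True)
    # Broadcasting replaces the repeat_interleave used in the original module
    return (groups / torch.sqrt(moment2 + eps)).reshape(T, B, C)


torch.manual_seed(0)
x = torch.randn(5, 2, 8)  # T=5 tokens, B=2, C=8 -> 4 groups of 2 channels each
y = group_scaling_1d(x, group_num=4)
# Per-group mean square of the output is m2 / (m2 + eps), i.e. just below 1
print(y.reshape(5, 2, 4, 2).pow(2).mean(dim=-1))
```

Broadcasting over the trailing group axis gives the same result as the module's `repeat_interleave` followed by `reshape`, without materializing the repeated tensor.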

Thank you, I'd very much appreciate your time.
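For question 2, a small CPU-only sketch can check both points. The helper `update_phi` and its `use_else` flag are hypothetical and only mimic the two quoted lines. The sketch confirms that `var.mean(dim=0, keepdim=True)` is a no-op on a (1, C, 1, 1) tensor. It also shows that, as posted, the cumulative average during warmup is followed by the exponential moving average in the same step.

```python
import torch


def update_phi(running_phi, var, current_iter, warmup_iters, afwd, use_else):
    """Mimics the running_phi update in PowerFunction.forward, in place.

    use_else=False reproduces the posted code; use_else=True skips the EMA during warmup.
    """
    if current_iter < warmup_iters:
        # Cumulative moving average over the warmup iterations
        running_phi.copy_(running_phi * (current_iter - 1) / current_iter
                          + var.mean(dim=0, keepdim=True) / current_iter)
        if use_else:
            return running_phi
    # Exponential moving average (as posted, this also runs during warmup)
    running_phi.copy_(afwd * running_phi + (1 - afwd) * var.mean(dim=0, keepdim=True))
    return running_phi


C = 3
var = torch.rand(1, C, 1, 1)
# Reducing over a dimension of size 1 with keepdim=True returns the same values
assert torch.allclose(var.mean(dim=0, keepdim=True), var)

phi_as_posted = torch.ones(1, C, 1, 1)
phi_with_else = torch.ones(1, C, 1, 1)
batch_vars = [torch.full((1, C, 1, 1), 1.0), torch.full((1, C, 1, 1), 3.0)]
for it, v in enumerate(batch_vars, start=1):  # iters start at 1, as in MaskPowerNorm
    update_phi(phi_as_posted, v, it, warmup_iters=10, afwd=0.9, use_else=False)
    update_phi(phi_with_else, v, it, warmup_iters=10, afwd=0.9, use_else=True)

print(phi_with_else.flatten())  # plain average of 1.0 and 3.0
print(phi_as_posted.flatten())  # average, then pulled towards the latest batch
```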

congwang093, Jul 29 '23

Hello,

I have some more questions following on from what was asked above.

In the NormSelect function https://github.com/sIncerass/powernorm/blob/9ea6226a3203d5d6fcee07a5c6dec38ec6bc5e9f/fairseq/modules/norm_select.py#L12-L19

For batch norm we use MaskSyncBatchNorm, a version of synchronized batch norm needed for multi-GPU training, but for PowerNorm I didn't see any SyncPowerNorm. Is that because PowerNorm doesn't need a synchronized version? As I understand it, a synchronized version is needed whenever batch statistics are used (which is why there is no synchronized layer norm).

Also, the paper's appendix mentions that a synchronized version is used for "PN-V". If possible, could you release that part of the code as well?

lumliolum, Jul 31 '24
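On whether PowerNorm needs synchronization: during warmup its training-time statistic `var` is a per-channel mean of x² over the unpadded tokens. Under data parallelism, each rank would therefore normalize with a statistic computed from its own shard only. The CPU-only sketch below simulates two ranks with plain tensors rather than `torch.distributed`, and its helper names are hypothetical. It shows that the exact global statistic comes from reducing per-channel sums of x² together with token counts. Averaging per-rank means (a SUM all-reduce followed by division by the world size) is only exact when every rank has the same number of unpadded tokens.

```python
import torch


def shard_stats(x: torch.Tensor):
    """Per-channel sum of squares and token count for one rank's (tokens, C) shard."""
    return (x * x).sum(dim=0), x.shape[0]


def combine_exact(stats):
    """What a SUM all-reduce of both the sums and the counts would give."""
    total_sq = sum(s for s, _ in stats)
    total_n = sum(n for _, n in stats)
    return total_sq / total_n


def combine_mean_of_means(stats):
    """Averaging per-rank means, i.e. SUM all-reduce of the means / world_size."""
    return sum(s / n for s, n in stats) / len(stats)


torch.manual_seed(0)
C = 4
# Two simulated ranks with different numbers of unpadded tokens (e.g. after masking)
rank0 = torch.randn(10, C)
rank1 = 3.0 * torch.randn(2, C)
full = torch.cat([rank0, rank1])

stats = [shard_stats(rank0), shard_stats(rank1)]
global_var = (full * full).mean(dim=0)  # the single-process (full batch) statistic
print(torch.allclose(combine_exact(stats), global_var))          # exact
print(torch.allclose(combine_mean_of_means(stats), global_var))  # generally not
```

In a real module, the two reductions would be `dist.all_reduce(..., op=dist.ReduceOp.SUM)` calls on the sum tensor and on a count tensor.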

The discussions around GroupScaling are in #9 and #8.

lumliolum, Jul 31 '24

Thanks for the clarification. I'm not the original author, but I'm trying to implement this myself right now; if I'm successful, I will share it here.

Ice-Citron, Aug 25 '24

@lumliolum I have come up with this for now and checked it against the original mask_powernorm.py as much as I could. Can you help me check it too?

I ran the command "torchrun --standalone --nproc_per_node=4 test.py" to run this code on a machine with PyTorch installed and 4 GPUs (V100s in my case). I will double-check that the code works soon. With the help of Claude 3.5 and GPT-4, I managed to come up with this makeshift solution. What I did was:

  1. Implemented a synchronized PowerNorm version and pasted the original version into the same file.
  2. Then I passed the same data into both and initialised the same model twice, making sure the copies were exactly the same by calling torch.manual_seed(42) before generating the data with randn and again before initialising each random NN model: one using SyncPowerNorm (run on all 4 GPUs) and one using the original MaskPowerNorm.
  3. Then I ran the file, with controlled_print calls placed throughout to monitor the forward and backward passes of each network. I compared the gradients of the two networks (one using the original PowerNorm, the other using my synchronized version) by their mean and std, and computed the mean absolute difference between them.
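Steps 2 and 3 can be sketched on CPU without any GPUs. The sketch uses a small stand-in model rather than `TestModel`, and the helper `grad_diffs` is hypothetical. Two models built right after the same `torch.manual_seed` call start with identical parameters. Gradients are then matched by parameter name, with the `module.` prefix that `DistributedDataParallel` adds stripped first.

```python
import torch
import torch.nn as nn


def build_model(seed: int) -> nn.Module:
    torch.manual_seed(seed)  # same seed -> identical initial parameters
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))


def grad_diffs(model_a: nn.Module, model_b: nn.Module) -> dict:
    """Mean absolute gradient difference per parameter, matched by name.

    DDP-wrapped models report names such as 'module.0.weight', so the prefix is removed.
    """
    grads_b = {name.removeprefix("module."): p.grad for name, p in model_b.named_parameters()}
    diffs = {}
    for name, p in model_a.named_parameters():
        key = name.removeprefix("module.")
        if p.grad is not None and grads_b.get(key) is not None:
            diffs[key] = (p.grad - grads_b[key]).abs().mean().item()
    return diffs


data = torch.randn(16, 4)
model_a, model_b = build_model(42), build_model(42)
for m in (model_a, model_b):
    m(data).sum().backward()

print(grad_diffs(model_a, model_b))  # identical models and data -> all zeros
```

Matching by name rather than zipping two `named_parameters()` iterators makes a mismatch in parameter order or count visible instead of silently pairing the wrong tensors.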

The result: the mean absolute difference is 0, and the two networks (one using the original MaskPowerNorm, the other using my SyncPowerNorm) appear to produce exactly the same results, as seen in the image below:

[Screenshot 2024-08-28 at 2:28:37 AM]

Success? I'm not sure yet. I will try and double check and get back to you.

# torchrun --standalone --nproc_per_node=4 test.py

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import os

# Global variable to control printing
PRINT_ALL_RANKS = True

def controlled_print(message):
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if PRINT_ALL_RANKS or local_rank == 0:
        print(f"Rank {local_rank}: {message}")

class GroupScaling1D(nn.Module):
    def __init__(self, eps=1e-5, group_num=4):
        super(GroupScaling1D, self).__init__()
        self.eps = eps
        self.group_num = group_num

    def extra_repr(self):
        return f'eps={self.eps}, group={self.group_num}'

    def forward(self, input):
        T, B, C = input.shape[0], input.shape[1], input.shape[2]
        Cg = C // self.group_num
        gn_input = input.contiguous().reshape(T, B, self.group_num, Cg)
        moment2 = torch.repeat_interleave(torch.mean(gn_input * gn_input, dim=3, keepdim=True),
            repeats=Cg, dim=-1).contiguous().reshape(T, B, C)
        return input / torch.sqrt(moment2 + self.eps)

class PowerFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias, running_phi, eps, afwd, abkw, ema_gz,
                debug, warmup_iters, current_iter, mask_x):
        # Original PowerFunction forward code here
        ctx.eps = eps
        ctx.debug = debug
        current_iter = current_iter.item()
        ctx.current_iter = current_iter
        ctx.warmup_iters = warmup_iters
        ctx.abkw = abkw
        rmax = 1
        N, C, H, W = x.size()
        x2 = (mask_x * mask_x).mean(dim=0)

        var = x2.reshape(1, C, 1, 1)
        if current_iter <= warmup_iters:
            z = x /(var + eps).sqrt()
        else:
            z = x /(running_phi + eps).sqrt()
            
        y = z
        ctx.save_for_backward(z, var, weight, ema_gz)

        if current_iter < warmup_iters:
            running_phi.copy_(running_phi * (current_iter-1)/current_iter + var.mean(dim=0, keepdim=True)/current_iter)
        running_phi.copy_(afwd*running_phi + (1-afwd)*var.mean(dim=0, keepdim=True))
        y = weight.reshape(1,C,1,1) * y + bias.reshape(1,C,1,1)

        controlled_print(f"Original Forward - Input mean: {x.mean().item()}, std: {x.std().item()}")
        controlled_print(f"Original Forward - Weight mean: {weight.mean().item()}, std: {weight.std().item()}")
        controlled_print(f"Original Forward - Bias mean: {bias.mean().item()}, std: {bias.std().item()}")
        controlled_print(f"Original Forward - Running phi mean: {running_phi.mean().item()}, std: {running_phi.std().item()}")
        controlled_print(f"Original Forward - Var mean: {var.mean().item()}, std: {var.std().item()}")
        controlled_print(f"Original Forward - Output mean: {y.mean().item()}, std: {y.std().item()}")

        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps
        debug = ctx.debug
        current_iter = ctx.current_iter
        warmup_iters = ctx.warmup_iters
        abkw = ctx.abkw

        N, C, H, W = grad_output.size()
        z, var, weight, ema_gz = ctx.saved_tensors

        y = z
        g = grad_output * weight.reshape(1, C, 1, 1)
        g = g * 1

        gz = (g * z).mean(dim=3).mean(dim=2).mean(dim=0)

        approx_grad_g = (g - (1 - abkw) * ema_gz * z)
        ema_gz.add_((approx_grad_g * z).mean(dim=3, keepdim=True).mean(dim=2, keepdim=True).mean(dim=0, keepdim=True))

        gx = 1. / torch.sqrt(var + eps) * approx_grad_g 
        grad_weight = (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0)
        grad_bias = grad_output.sum(dim=3).sum(dim=2).sum(dim=0)

        is_sync = hasattr(ctx, 'process_group')
        prefix = "Sync" if is_sync else "Original"
        
        controlled_print(f"{prefix} Backward - Grad output mean: {grad_output.mean().item()}, std: {grad_output.std().item()}")
        controlled_print(f"{prefix} Backward - Grad input mean: {gx.mean().item()}, std: {gx.std().item()}")
        controlled_print(f"{prefix} Backward - Grad weight mean: {grad_weight.mean().item()}, std: {grad_weight.std().item()}")
        controlled_print(f"{prefix} Backward - Grad bias mean: {grad_bias.mean().item()}, std: {grad_bias.std().item()}")

        return gx, grad_weight, grad_bias, None, None, None, None, None, None, None, None, None, None, None, None

class MaskPowerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, alpha_fwd=0.9, alpha_bkw=0.9,
                 affine=True, warmup_iters=10000, group_num=1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer('running_phi', torch.ones(1,num_features,1,1))
        self.register_buffer('ema_gz', torch.zeros(1,num_features,1,1))
        self.register_buffer('iters', torch.zeros(1).type(torch.LongTensor))
        self.afwd = alpha_fwd
        self.abkw = alpha_bkw
        self.debug = False
        self.warmup_iters = warmup_iters
        self.gp = GroupScaling1D(group_num=group_num)
        self.group_num = group_num

    def extra_repr(self):
        return '{num_features}, eps={eps}, alpha_fwd={afwd}, alpha_bkw={abkw}, ' \
               'affine={affine}, warmup={warmup_iters}, group_num={group_num}'.format(**self.__dict__)

    def forward(self, input, pad_mask=None, is_encoder=False):
        shaped_input = (len(input.shape) == 2)
        if shaped_input:
            input = input.unsqueeze(0)
        
        orig_4d_shape = None
        if input.dim() == 4:  # N, C, H, W -> (H*W, N, C)
            orig_4d_shape = input.shape
            N, C, H, W = input.shape
            input = input.permute(2, 3, 0, 1).contiguous().view(H*W, N, C)
        
        T, B, C = input.shape
        input = self.gp(input)

        # construct the mask_input, size to be (BxL) x C: L is the real length here
        if pad_mask is None:
            mask_input = input.clone()
        else:
            # Transpose the bn_mask (B x T -> T x B)
            bn_mask = ~pad_mask
            bn_mask = bn_mask.transpose(0, 1)

        if pad_mask is not None:
            pad_size = (~bn_mask).sum()
            mask_input = input[bn_mask, :]
        else:
            mask_input = input.clone()

        mask_input = mask_input.reshape(-1, self.num_features)

        input = input.permute(1, 2, 0).contiguous()
        input_shape = input.size()
        input = input.reshape(input.size(0), self.num_features, -1)
        input = input.unsqueeze(-1)

        if self.training:
            self.iters.copy_(self.iters + 1)
            output = PowerFunction.apply(input, self.weight, self.bias, self.running_phi, self.eps, \
                        self.afwd, self.abkw, self.ema_gz, self.debug, self.warmup_iters, self.iters, mask_input)
            
        else:
            N, C, H, W = input.size()
            var = self.running_phi
            output = input / (var + self.eps).sqrt()
            output = self.weight.reshape(1,C,1,1) * output + self.bias.reshape(1,C,1,1)

        output = output.reshape(input_shape)
        output = output.permute(2, 0, 1).contiguous()
        # Reshape it back to the caller's layout.
        if orig_4d_shape is not None:
            # (H*W, N, C) -> (H, W, N, C) -> (N, C, H, W), undoing the permute/view above
            N, C, H, W = orig_4d_shape
            output = output.reshape(H, W, N, C).permute(2, 3, 0, 1).contiguous()
        if shaped_input:
            output = output.squeeze(0)

        return output


class SyncPowerFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias, running_phi, eps, afwd, abkw, ema_gz,
                debug, warmup_iters, current_iter, mask_x, process_group, world_size):
        ctx.eps = eps
        ctx.debug = debug
        current_iter = current_iter.item()
        ctx.current_iter = current_iter
        ctx.warmup_iters = warmup_iters
        ctx.abkw = abkw
        ctx.process_group = process_group
        ctx.world_size = world_size

        N, C, H, W = x.size()
        x2 = (mask_x * mask_x).mean(dim=0)

        var = x2.reshape(1, C, 1, 1)

        # Synchronize var across GPUs
        if process_group is not None:
            dist.all_reduce(var, op=dist.ReduceOp.SUM, group=process_group)
            var /= world_size

        if current_iter <= warmup_iters:
            z = x / (var + eps).sqrt()
        else:
            z = x / (running_phi + eps).sqrt()

        y = z
        ctx.save_for_backward(z, var, weight, ema_gz)

        if current_iter < warmup_iters:
            running_phi.copy_(running_phi * (current_iter-1)/current_iter + var.mean(dim=0, keepdim=True)/current_iter)
        running_phi.copy_(afwd*running_phi + (1-afwd)*var.mean(dim=0, keepdim=True))
        y = weight.reshape(1,C,1,1) * y + bias.reshape(1,C,1,1)

        controlled_print(f"Sync Forward - Input mean: {x.mean().item()}, std: {x.std().item()}")
        controlled_print(f"Sync Forward - Weight mean: {weight.mean().item()}, std: {weight.std().item()}")
        controlled_print(f"Sync Forward - Bias mean: {bias.mean().item()}, std: {bias.std().item()}")
        controlled_print(f"Sync Forward - Running phi mean: {running_phi.mean().item()}, std: {running_phi.std().item()}")
        controlled_print(f"Sync Forward - Var mean: {var.mean().item()}, std: {var.std().item()}")
        controlled_print(f"Sync Forward - Output mean: {y.mean().item()}, std: {y.std().item()}")

        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps
        debug = ctx.debug
        current_iter = ctx.current_iter
        warmup_iters = ctx.warmup_iters
        abkw = ctx.abkw

        N, C, H, W = grad_output.size()
        z, var, weight, ema_gz = ctx.saved_tensors

        y = z
        g = grad_output * weight.reshape(1, C, 1, 1)
        g = g * 1

        gz = (g * z).mean(dim=3).mean(dim=2).mean(dim=0)

        approx_grad_g = (g - (1 - abkw) * ema_gz * z)
        ema_gz.add_((approx_grad_g * z).mean(dim=3, keepdim=True).mean(dim=2, keepdim=True).mean(dim=0, keepdim=True))

        gx = 1. / torch.sqrt(var + eps) * approx_grad_g 
        grad_weight = (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0)
        grad_bias = grad_output.sum(dim=3).sum(dim=2).sum(dim=0)

        is_sync = hasattr(ctx, 'process_group')
        prefix = "Sync" if is_sync else "Original"

        if is_sync:
            process_group = ctx.process_group
            world_size = ctx.world_size
            dist.all_reduce(grad_weight, op=dist.ReduceOp.SUM, group=process_group)
            dist.all_reduce(grad_bias, op=dist.ReduceOp.SUM, group=process_group)
            grad_weight /= world_size
            grad_bias /= world_size
        
        controlled_print(f"{prefix} Backward - Grad output mean: {grad_output.mean().item()}, std: {grad_output.std().item()}")
        controlled_print(f"{prefix} Backward - Grad input mean: {gx.mean().item()}, std: {gx.std().item()}")
        controlled_print(f"{prefix} Backward - Grad weight mean: {grad_weight.mean().item()}, std: {grad_weight.std().item()}")
        controlled_print(f"{prefix} Backward - Grad bias mean: {grad_bias.mean().item()}, std: {grad_bias.std().item()}")

        return gx, grad_weight, grad_bias, None, None, None, None, None, None, None, None, None, None, None, None

class SyncMaskPowerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, alpha_fwd=0.9, alpha_bkw=0.9,
                 affine=True, warmup_iters=10000, group_num=1, process_group=None):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.process_group = process_group
        self.world_size = dist.get_world_size(process_group) if process_group is not None else 1

        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer('running_phi', torch.ones(1,num_features,1,1))
        self.register_buffer('ema_gz', torch.zeros(1,num_features,1,1))
        self.register_buffer('iters', torch.zeros(1).type(torch.LongTensor))

        self.afwd = alpha_fwd
        self.abkw = alpha_bkw

        self.debug = False
        self.warmup_iters = warmup_iters
        self.gp = GroupScaling1D(group_num=group_num)
        self.group_num = group_num

    def extra_repr(self):
        return '{num_features}, eps={eps}, alpha_fwd={afwd}, alpha_bkw={abkw}, ' \
               'affine={affine}, warmup={warmup_iters}, group_num={group_num}'.format(**self.__dict__)

    def forward(self, input, pad_mask=None, is_encoder=False):
        shaped_input = (len(input.shape) == 2)
        if shaped_input:
            input = input.unsqueeze(0)
        
        orig_4d_shape = None
        if input.dim() == 4:  # N, C, H, W -> (H*W, N, C)
            orig_4d_shape = input.shape
            N, C, H, W = input.shape
            input = input.permute(2, 3, 0, 1).contiguous().view(H*W, N, C)
        
        T, B, C = input.shape
        input = self.gp(input)

        if pad_mask is None:
            mask_input = input.clone()
        else:
            bn_mask = ~pad_mask
            bn_mask = bn_mask.transpose(0, 1)

        if pad_mask is not None:
            pad_size = (~bn_mask).sum()
            mask_input = input[bn_mask, :]
        else:
            mask_input = input.clone()

        mask_input = mask_input.reshape(-1, self.num_features)

        input = input.permute(1, 2, 0).contiguous()
        input_shape = input.size()
        input = input.reshape(input.size(0), self.num_features, -1)
        input = input.unsqueeze(-1)

        if self.training:
            self.iters.copy_(self.iters + 1)
            output = SyncPowerFunction.apply(input, self.weight, self.bias, self.running_phi, self.eps,
                        self.afwd, self.abkw, self.ema_gz, self.debug, self.warmup_iters, self.iters, mask_input,
                        self.process_group, self.world_size)
        else:
            N, C, H, W = input.size()
            var = self.running_phi
            output = input / (var + self.eps).sqrt()
            output = self.weight.reshape(1,C,1,1) * output + self.bias.reshape(1,C,1,1)

        output = output.reshape(input_shape)
        output = output.permute(2, 0, 1).contiguous()
        # Reshape it back to the caller's layout.
        if orig_4d_shape is not None:
            # (H*W, N, C) -> (H, W, N, C) -> (N, C, H, W), undoing the permute/view above
            N, C, H, W = orig_4d_shape
            output = output.reshape(H, W, N, C).permute(2, 3, 0, 1).contiguous()
        if shaped_input:
            output = output.squeeze(0)

        return output

class TestModel(nn.Module):
    def __init__(self, norm_layer):
        super(TestModel, self).__init__()
        self.conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.norm = norm_layer(64)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.norm(self.conv(x)))

def run_test(local_rank, world_size):
    torch.cuda.set_device(local_rank)  # bind this process to its own GPU before NCCL init
    dist.init_process_group(backend="nccl", init_method='env://', world_size=world_size, rank=local_rank)

    # Set same seed for all processes
    torch.manual_seed(42)
    
    # Generate same data for all processes
    batch_size = 32
    torch.manual_seed(42)
    data = torch.randn(batch_size, 3, 64, 64).cuda(local_rank)
    
    # Ensure all processes have the same data
    dist.broadcast(data, src=0)

    if local_rank == 0:
        # Run original MaskPowerNorm on single GPU
        torch.manual_seed(42)
        model_original = TestModel(lambda num_features: MaskPowerNorm(num_features)).cuda(local_rank)
        out_original = model_original(data)
        loss_original = out_original.sum()
        loss_original.backward()
        
        controlled_print("Running original PowerNorm on single GPU")
        controlled_print(f"Original output mean: {out_original.mean().item()}")

    # Run SyncMaskPowerNorm on all GPUs
    torch.manual_seed(42)
    model_sync = TestModel(lambda num_features: SyncMaskPowerNorm(num_features, process_group=dist.group.WORLD)).cuda(local_rank)
    model_sync = DDP(model_sync, device_ids=[local_rank])
    
    controlled_print("Running SyncPowerNorm")
    out_sync = model_sync(data)
    controlled_print(f"Sync output mean on rank {local_rank}: {out_sync.mean().item()}")

    loss_sync = out_sync.sum()
    loss_sync.backward()

    if local_rank == 0:
        for (name_o, param_o), (name_s, param_s) in zip(model_original.named_parameters(), model_sync.named_parameters()):
            if param_o.grad is not None and param_s.grad is not None:
                grad_diff = (param_o.grad - param_s.grad).abs().mean().item()
                controlled_print(f"Gradient difference for {name_o}:")
                controlled_print(f"  Original - mean: {param_o.grad.mean().item()}, std: {param_o.grad.std().item()}")
                controlled_print(f"  Sync     - mean: {param_s.grad.mean().item()}, std: {param_s.grad.std().item()}")
                controlled_print(f"  Absolute difference: {grad_diff}")


    dist.barrier()
    dist.destroy_process_group()

def main():
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    controlled_print(f"Running on rank {local_rank} of {world_size}")
    run_test(local_rank, world_size)

if __name__ == "__main__":
    main()

Ice-Citron, Aug 27 '24
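One caveat about the test above: `dist.broadcast(data, src=0)` gives every rank the same batch. Each rank's local `var` therefore already equals the synchronized one, so the comparison cannot tell a synchronized layer from an unsynchronized one. The CPU-only sketch below illustrates a more discriminating setup. It uses simulated ranks, hypothetical variable names, and equal-sized shards: each rank gets a different slice of the full batch, and the synchronized statistic is compared against the single-process full-batch value.

```python
import torch

torch.manual_seed(0)
world_size, C = 4, 8
# Rows differ in scale, so different shards have clearly different statistics
full_batch = torch.randn(32, C) * torch.linspace(0.5, 2.0, 32).unsqueeze(1)

# Per-channel mean of x^2 (PowerNorm's `var`) for the full batch on one process
full_var = (full_batch ** 2).mean(dim=0)

# Shard the batch, e.g. data.chunk(world_size)[local_rank] in the real torchrun script
shards = full_batch.chunk(world_size)
local_vars = [(s ** 2).mean(dim=0) for s in shards]

# Equal-sized shards: SUM all-reduce divided by world_size == mean of the local values
synced_var = torch.stack(local_vars).mean(dim=0)

print(torch.allclose(synced_var, full_var))     # the synchronized statistic matches
print(torch.allclose(local_vars[0], full_var))  # an unsynchronized rank does not

# With identical data on every rank (as in the broadcast test) the local value already
# equals the synchronized one, so that setup cannot distinguish the two implementations.
same = [(full_batch ** 2).mean(dim=0)] * world_size
print(torch.allclose(torch.stack(same).mean(dim=0), same[0]))
```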

Hey @Ice-Citron

Thanks for sharing the code for SyncPowerNorm. Currently I don't have access to the server (with multiple GPUs) to run your code. I'll probably get access back in a few weeks (if possible, I will update you then).

I don't know much about how distributed systems and code work (apologies in advance if I get something wrong), but I had a look anyway. Some questions I have:

  • In the forward of SyncPowerNorm, you use all_reduce on the variance, and since running_phi depends only on var, it will be the same across all devices. In the backward you all_reduce grad_weight and grad_bias but not gx. Why is that?
  • Following on from that: running_ema (the running statistic in the backward, ema_gz in the code) has no all_reduce, so at each step it will differ across devices. Shouldn't we use all_reduce for running_ema as well?

lumliolum, Aug 28 '24
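On the second point, here is a CPU-only sketch with simulated ranks and equal-sized shards. The helper `ema_increment` is hypothetical and copies the `ema_gz` update line from the backward above. Each rank adds the mean of `approx_grad_g * z` over its own shard, so without a reduction the copies disagree. Averaging the per-rank increments keeps them identical and equal to the single-process update; in a real backward this would be `dist.all_reduce(..., op=dist.ReduceOp.SUM)` followed by division by the world size. With DDP's default `broadcast_buffers=True`, rank 0's buffers (including `ema_gz`) are broadcast at the start of each forward pass. An unsynchronized version would therefore effectively track only rank 0's shard rather than drift apart indefinitely.

```python
import torch


def ema_increment(ema_gz, g, z, abkw):
    """Increment added to ema_gz in PowerFunction.backward for one (N, C, H, W) shard."""
    approx_grad_g = g - (1 - abkw) * ema_gz * z
    return (approx_grad_g * z).mean(dim=(0, 2, 3), keepdim=True)


torch.manual_seed(0)
world_size, N, C, abkw = 2, 4, 3, 0.9
g_full = torch.randn(world_size * N, C, 5, 1)  # stands in for grad_output * weight
z_full = torch.randn(world_size * N, C, 5, 1)  # stands in for the normalized input z
ema_gz = torch.full((1, C, 1, 1), 0.2)          # same starting value on every rank

shards = zip(g_full.chunk(world_size), z_full.chunk(world_size))
local_increments = [ema_increment(ema_gz, g, z, abkw) for g, z in shards]

# Without a reduction, each rank adds only its own shard's increment
unsynced = [ema_gz + inc for inc in local_increments]
# SUM all-reduce of the increments divided by world_size: identical on every rank
synced = ema_gz + torch.stack(local_increments).sum(dim=0) / world_size
# ...and equal to what one process would compute on the full batch
single_process = ema_gz + ema_increment(ema_gz, g_full, z_full, abkw)

print(torch.allclose(unsynced[0], unsynced[1]))  # ranks disagree
print(torch.allclose(synced, single_process))     # matches the full-batch update
```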

@lumliolum Hi. Sorry for the late reply; I've been very busy with school and other projects. I have 2 hours now and will try to answer your questions by verifying and figuring out what's going on with my code, to make sure that it works.

Ice-Citron, Aug 31 '24

@lumliolum Give me a bit longer. I only just managed to understand how the maths works, and only just realised that this normalisation layer is meant for ViTs instead of language transformers. lol

Ice-Citron, Sep 01 '24
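For reference, here are the computations performed by `PowerFunction` as posted in this thread, read directly off the code rather than taken from the paper's notation. Here t is `current_iter`, T_w is `warmup_iters`, 𝓜 is the set of unpadded positions used for `mask_x`, ψ² is `var`, φ is `running_phi`, Γ is `ema_gz`, and γ, β are `weight` and `bias`:

```latex
\[
\begin{aligned}
\psi_c^2 &= \frac{1}{|\mathcal{M}|} \sum_{i \in \mathcal{M}} x_{i,c}^2 \\
z_{i,c} &= \frac{x_{i,c}}{\sqrt{s_c + \epsilon}}, \qquad
  s_c = \begin{cases} \psi_c^2, & t \le T_w \\ \phi_c^{(t-1)}, & t > T_w \end{cases} \\
y_{i,c} &= \gamma_c \, z_{i,c} + \beta_c \\
\tilde{\phi}_c &= \begin{cases} \frac{t-1}{t}\,\phi_c^{(t-1)} + \frac{1}{t}\,\psi_c^2, & t < T_w \\ \phi_c^{(t-1)}, & t \ge T_w \end{cases} \\
\phi_c^{(t)} &= \alpha_{\mathrm{fwd}}\,\tilde{\phi}_c + (1 - \alpha_{\mathrm{fwd}})\,\psi_c^2 \\[4pt]
g_{i,c} &= \gamma_c \, \frac{\partial L}{\partial y_{i,c}}, \qquad
  \hat{g}_{i,c} = g_{i,c} - (1 - \alpha_{\mathrm{bkw}})\,\Gamma_c\, z_{i,c} \\
\frac{\partial L}{\partial x_{i,c}} &\approx \frac{\hat{g}_{i,c}}{\sqrt{\psi_c^2 + \epsilon}}, \qquad
  \frac{\partial L}{\partial \gamma_c} = \sum_i \frac{\partial L}{\partial y_{i,c}}\, z_{i,c}, \qquad
  \frac{\partial L}{\partial \beta_c} = \sum_i \frac{\partial L}{\partial y_{i,c}} \\
\Gamma_c &\leftarrow \Gamma_c + \overline{\hat{g}_{\cdot,c}\, z_{\cdot,c}}
  = \Gamma_c\bigl(1 - (1 - \alpha_{\mathrm{bkw}})\,\overline{z_{\cdot,c}^2}\bigr) + \overline{g_{\cdot,c}\, z_{\cdot,c}}
\end{aligned}
\]
```

The overline is a mean over every position of the (N, C, H, W) tensor, padded positions included. When the mean of z² is close to 1, the last line reduces to Γ_c ← α_bkw Γ_c + mean(g z). So the plain `add_` behaves like an exponential moving average. In this code, the backward always divides by the batch statistic ψ², even after warmup.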

@lumliolum Actually yeah, good point. When I rushed out the code initially, I was just trying to get something running as fast as possible and didn't realise that. I will look at whether an all_reduce is necessary for running_ema.

As for why gx isn't synchronised, here is what GPT-4 said, which makes a lot of sense:


The reason gx (the gradient of the loss with respect to the input x) does not require synchronization through all_reduce or similar operations in a distributed training context like DDP (Distributed Data Parallel) primarily relates to how and where it is used in neural network training.

Understanding the Use of gx in Backpropagation

  1. Local Relevance:

    • gx represents the gradient of the network's loss with respect to its inputs at a particular layer. This gradient is used locally by each GPU to compute gradients for the weights and biases that are directly connected to these inputs within the same layer or previous layers.
    • Since each GPU processes a different subset of the data (mini-batch), gx computed on one GPU is specifically relevant to the forward pass computations and the immediate backpropagation calculations on that same GPU. There is no need to share or synchronize gx across GPUs because it does not directly influence the parameters being updated in other GPUs.
  2. Parameter Updates vs. Input Gradients:

    • In DDP, the primary concern is to synchronize the gradients of the parameters (weights and biases) across all GPUs before they are updated. This synchronization ensures that each GPU updates its model parameters based on the complete information from the entire distributed dataset, thus maintaining consistency across all replicas of the model.
    • gx, on the other hand, is used to propagate error gradients back through the network to update parameters local to each layer. Each instance of backpropagation uses gx from the subsequent layer to calculate the gradients for parameters in the current layer, but these gradients do not need to be shared across GPUs because each GPU independently handles a portion of the data.
  3. No Global Impact:

    • gx affects only the local computation within a GPU and is used to calculate gradients of parameters that will subsequently be synchronized. Since the parameters are the elements that ultimately need to be consistent across all GPUs, only their gradients require synchronization.
    • If gx were part of a computation that influenced model parameters globally (across all GPUs), then it would indeed need to be synchronized. However, in standard neural network operations, this is not the case.

Practical Implications

In a practical scenario, consider a multi-layer neural network training across multiple GPUs:

  • Each GPU computes forward pass and backward pass locally for the batch of data it processes.
  • During the backward pass, gx for each layer is computed based on the output gradient from the subsequent layer (grad_output), which is also local to the GPU.
  • The calculated gradients of the weights and biases (which impact the model globally) are then synchronized across GPUs using operations like all_reduce to ensure every GPU has the same updated view before the next iteration of training.

Conclusion

gx not requiring synchronization across GPUs allows each part of the model to be trained as efficiently and quickly as possible on its batch of data, relying on the synchronization of parameter gradients to ensure overall model consistency. This approach optimizes network bandwidth and computation by avoiding unnecessary data transfers and maintains the independence of local computations, which is a key advantage in distributed training setups.


Question: is it wrong if gx is synchronised?


Synchronizing gx (the gradient of the loss with respect to the input at a particular layer) across GPUs in a distributed training setup isn't typically necessary, but doing so wouldn't be inherently "wrong"—rather, it would be inefficient and largely redundant. Let's explore why this is generally avoided and what implications it might have if done:

1. Redundancy and Efficiency

  • Redundancy: In most training scenarios, gx is used locally within the same GPU to compute gradients of parameters (weights and biases) that are earlier in the network relative to the current layer. Since gx is only used to compute updates that are local to specific layers, sharing this across GPUs does not contribute additional beneficial information that would influence the global model training.
  • Efficiency: Transferring gx across GPUs consumes bandwidth and computational resources, which could otherwise be used for synchronizing parameters that directly affect the consistency and performance of the model across all GPUs. The overhead introduced by unnecessarily synchronizing gx could slow down the training process without providing tangible benefits.

2. Practical Impact of Synchronization

  • No Benefit: Synchronizing gx across GPUs does not benefit the training process because each GPU is already handling its portion of the batch data. The local gx on each GPU is sufficient for the backpropagation through layers handled by that GPU.
  • Potential Confusion: Synchronizing gx might introduce complexity into the training logic without clear justification. It could potentially lead to confusion about the role and usage of gradients in distributed training setups, where the focus should ideally be on synchronizing parameters that affect the model globally.

3. Conceptual Clarity

  • Layer-Specific Calculations: Backpropagation is inherently a layer-specific calculation where each layer computes its output gradients based on inputs from the previous layer (or subsequent layer in the case of backpropagation). This sequence ensures that the error signal is propagated correctly backward through the network, adjusting each layer's parameters based on the local gradient computations.
  • Isolation of Layers: Each layer's parameters are updated based on gradients computed from its outputs. Synchronizing gx would imply an unnecessary linkage between the gradient flows of different parts of the network across different GPUs, which is not how neural network training is typically conceptualized or implemented.

4. Theoretical Considerations

  • Local vs. Global Gradients: While parameter gradients need to reflect global data characteristics (hence the synchronization), gx reflects how local computations at each layer should adjust to minimize the loss. Its value inherently depends on local data and layer-specific dynamics, which are contextually isolated from similar computations on other GPUs.

Conclusion

While not "wrong," synchronizing gx is not standard practice and does not align with the principles of efficient, effective distributed training. Keeping gx local optimizes computational resources and aligns with the foundational concepts of neural network training, where each part of the network learns based on the data it processes directly. The key to efficient distributed training is to minimize overhead while ensuring global consistency where it impacts the model—primarily at the level of model parameters, not intermediate gradients like gx.

Ice-Citron, Sep 01 '24
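The quoted answer can be refined with a small CPU-only autograd check that uses only standard `torch.nn` layers, nothing from this repo. `gx` itself is indeed never all-reduced. However, when a layer's forward uses batch statistics, the exact gradient for one sample also depends on the other samples, through per-channel sums of the output gradient. PyTorch's `SyncBatchNorm` all-reduces those sums in its backward pass for this reason. In `PowerFunction`, the corresponding term is approximated by `ema_gz`, which is why the follow-up question about synchronizing it matters.

```python
import torch
import torch.nn as nn


def cross_sample_grad(norm: nn.Module) -> float:
    """Total |d(sum of sample 0's output) / d(input of sample 1)|."""
    torch.manual_seed(0)
    x = torch.randn(4, 3, requires_grad=True)  # 4 samples, 3 channels
    out = norm(x)
    out[0].sum().backward()  # loss that only looks at sample 0
    return x.grad[1].abs().sum().item()


bn = nn.BatchNorm1d(3).train()  # normalizes with statistics of the whole batch
ln = nn.LayerNorm(3)            # normalizes each sample on its own
print(cross_sample_grad(bn) > 0)   # sample 1 influences sample 0 through the batch stats
print(cross_sample_grad(ln) == 0)  # no cross-sample dependence
```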

@lumliolum I'm looking into running_ema now. I have a feeling this needs to be synchronised, but let's see.

Ice-Citron, Sep 01 '24

@lumliolum Yep, you're correct. Here's the final code. Please let me know if there are still any logic errors. I double-checked it with Claude and GPT-4 and found none, but let me know if anything is wrong. Thanks.

# torchrun --standalone --nproc_per_node=4 test.py

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import os

# Global variable to control printing
PRINT_ALL_RANKS = True

def controlled_print(message):
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if PRINT_ALL_RANKS or local_rank == 0:
        print(f"Rank {local_rank}: {message}")

class GroupScaling1D(nn.Module):
    def __init__(self, eps=1e-5, group_num=4):
        super(GroupScaling1D, self).__init__()
        self.eps = eps
        self.group_num = group_num

    def extra_repr(self):
        return f'eps={self.eps}, group={self.group_num}'

    def forward(self, input):
        T, B, C = input.shape[0], input.shape[1], input.shape[2]
        Cg = C // self.group_num
        gn_input = input.contiguous().reshape(T, B, self.group_num, Cg)
        moment2 = torch.repeat_interleave(torch.mean(gn_input * gn_input, dim=3, keepdim=True),
            repeats=Cg, dim=-1).contiguous().reshape(T, B, C)
        return input / torch.sqrt(moment2 + self.eps)

class PowerFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias, running_phi, eps, afwd, abkw, ema_gz,
                debug, warmup_iters, current_iter, mask_x):
        # Original PowerFunction forward code here
        ctx.eps = eps
        ctx.debug = debug
        current_iter = current_iter.item()
        ctx.current_iter = current_iter
        ctx.warmup_iters = warmup_iters
        ctx.abkw = abkw
        rmax = 1
        N, C, H, W = x.size()
        x2 = (mask_x * mask_x).mean(dim=0)

        var = x2.reshape(1, C, 1, 1)
        if current_iter <= warmup_iters:
            # var is the quadratic mean mean(x^2) (no mean subtraction), so this divides by the RMS, not the standard deviation
            z = x /(var + eps).sqrt()
        else:
            z = x /(running_phi + eps).sqrt() # same thing as above, but using running stats instead
            
        y = z
        ctx.save_for_backward(z, var, weight, ema_gz)

        if current_iter < warmup_iters:
            running_phi.copy_(running_phi * (current_iter-1)/current_iter + var.mean(dim=0, keepdim=True)/current_iter) # cumulative moving average
        running_phi.copy_(afwd*running_phi + (1-afwd)*var.mean(dim=0, keepdim=True)) # exponential moving average
        y = weight.reshape(1,C,1,1) * y + bias.reshape(1,C,1,1)

        controlled_print(f"Original Forward - Input mean: {x.mean().item()}, std: {x.std().item()}")
        controlled_print(f"Original Forward - Weight mean: {weight.mean().item()}, std: {weight.std().item()}")
        controlled_print(f"Original Forward - Bias mean: {bias.mean().item()}, std: {bias.std().item()}")
        controlled_print(f"Original Forward - Running phi mean: {running_phi.mean().item()}, std: {running_phi.std().item()}")
        controlled_print(f"Original Forward - Var mean: {var.mean().item()}, std: {var.std().item()}")
        controlled_print(f"Original Forward - Output mean: {y.mean().item()}, std: {y.std().item()}")

        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps
        debug = ctx.debug
        current_iter = ctx.current_iter
        warmup_iters = ctx.warmup_iters
        abkw = ctx.abkw

        N, C, H, W = grad_output.size()
        z, var, weight, ema_gz = ctx.saved_tensors

        y = z
        g = grad_output * weight.reshape(1, C, 1, 1)
        g = g * 1

        gz = (g * z).mean(dim=3).mean(dim=2).mean(dim=0)

        approx_grad_g = (g - (1 - abkw) * ema_gz * z) # approximate backward: the running ema_gz replaces the batch statistic mean(g * z)
        ema_gz.add_((approx_grad_g * z).mean(dim=3, keepdim=True).mean(dim=2, keepdim=True).mean(dim=0, keepdim=True))

        gx = 1. / torch.sqrt(var + eps) * approx_grad_g # REFER TO NOTES REGARDING BACKPROP DERIVATIVE EQUATION
        grad_weight = (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0)
        grad_bias = grad_output.sum(dim=3).sum(dim=2).sum(dim=0)

        is_sync = hasattr(ctx, 'process_group')
        prefix = "Sync" if is_sync else "Original"
        
        controlled_print(f"{prefix} Backward - Grad output mean: {grad_output.mean().item()}, std: {grad_output.std().item()}")
        controlled_print(f"{prefix} Backward - Grad input mean: {gx.mean().item()}, std: {gx.std().item()}")
        controlled_print(f"{prefix} Backward - Grad weight mean: {grad_weight.mean().item()}, std: {grad_weight.std().item()}")
        controlled_print(f"{prefix} Backward - Grad bias mean: {grad_bias.mean().item()}, std: {grad_bias.std().item()}")

        # One gradient per forward input (12); non-differentiable inputs get None
        return gx, grad_weight, grad_bias, None, None, None, None, None, None, None, None, None

class MaskPowerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, alpha_fwd=0.9, alpha_bkw=0.9,
                 affine=True, warmup_iters=10000, group_num=1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer('running_phi', torch.ones(1,num_features,1,1))
        self.register_buffer('ema_gz', torch.zeros(1,num_features,1,1))
        self.register_buffer('iters', torch.zeros(1).type(torch.LongTensor))
        self.afwd = alpha_fwd
        self.abkw = alpha_bkw
        self.debug = False
        self.warmup_iters = warmup_iters
        self.gp = GroupScaling1D(group_num=group_num)
        self.group_num = group_num

    def extra_repr(self):
        return '{num_features}, eps={eps}, alpha_fwd={afwd}, alpha_bkw={abkw}, ' \
               'affine={affine}, warmup={warmup_iters}, group_num={group_num}'.format(**self.__dict__)

    def forward(self, input, pad_mask=None, is_encoder=False):
        shaped_input = (len(input.shape) == 2)
        if shaped_input:
            input = input.unsqueeze(0)
        
        is_4d = (input.dim() == 4)
        if is_4d:  # N, C, H, W -> (H*W, N, C)
            N, C, H, W = input.shape
            nchw_shape = (N, C, H, W)
            input = input.permute(2, 3, 0, 1).contiguous().view(H*W, N, C)
        
        T, B, C = input.shape
        input = self.gp(input)

        # construct the mask_input, size to be (BxL) x C: L is the real length here
        if pad_mask is None:
            mask_input = input.clone()
        else:
            # Transpose the bn_mask (B x T -> T x B) and keep only the non-padded positions
            bn_mask = ~pad_mask
            bn_mask = bn_mask.transpose(0, 1)
            mask_input = input[bn_mask, :]

        mask_input = mask_input.reshape(-1, self.num_features)

        input = input.permute(1, 2, 0).contiguous()
        input_shape = input.size()
        input = input.reshape(input.size(0), self.num_features, -1)
        input = input.unsqueeze(-1)

        if self.training:
            self.iters.copy_(self.iters + 1)
            output = PowerFunction.apply(input, self.weight, self.bias, self.running_phi, self.eps,
                        self.afwd, self.abkw, self.ema_gz, self.debug, self.warmup_iters, self.iters, mask_input)
        else:
            N, C, H, W = input.size()
            var = self.running_phi
            output = input / (var + self.eps).sqrt()
            output = self.weight.reshape(1,C,1,1) * output + self.bias.reshape(1,C,1,1)

        output = output.reshape(input_shape)
        output = output.permute(2, 0, 1).contiguous()
        # Undo the NCHW -> (H*W, N, C) reshape from the top so conv inputs get NCHW back
        if is_4d:
            N, C, H, W = nchw_shape
            output = output.view(H, W, N, C).permute(2, 3, 0, 1).contiguous()
        # Reshape it.
        if shaped_input:
            output = output.squeeze(0)

        return output


class SyncPowerFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias, running_phi, eps, afwd, abkw, ema_gz,
                debug, warmup_iters, current_iter, mask_x, process_group, world_size):
        ctx.eps = eps
        ctx.debug = debug
        current_iter = current_iter.item()
        ctx.current_iter = current_iter
        ctx.warmup_iters = warmup_iters
        ctx.abkw = abkw
        ctx.process_group = process_group
        ctx.world_size = world_size

        N, C, H, W = x.size()
        x2 = (mask_x * mask_x).mean(dim=0)

        var = x2.reshape(1, C, 1, 1)

        # Synchronize var across GPUs
        if process_group is not None:
            dist.all_reduce(var, op=dist.ReduceOp.AVG, group=process_group) # no need to divide, because already averaging

        if current_iter <= warmup_iters:
            z = x / (var + eps).sqrt()
        else:
            z = x / (running_phi + eps).sqrt()

        y = z
        ctx.save_for_backward(z, var, weight, ema_gz)

        if current_iter < warmup_iters:
            running_phi.copy_(running_phi * (current_iter-1)/current_iter + var.mean(dim=0, keepdim=True)/current_iter)
        running_phi.copy_(afwd*running_phi + (1-afwd)*var.mean(dim=0, keepdim=True))

        # Synchronize running_phi across all processes
        if process_group is not None:
            torch.distributed.all_reduce(running_phi, op=torch.distributed.ReduceOp.AVG, group=process_group)

        y = weight.reshape(1,C,1,1) * y + bias.reshape(1,C,1,1)

        controlled_print(f"Sync Forward - Input mean: {x.mean().item()}, std: {x.std().item()}")
        controlled_print(f"Sync Forward - Weight mean: {weight.mean().item()}, std: {weight.std().item()}")
        controlled_print(f"Sync Forward - Bias mean: {bias.mean().item()}, std: {bias.std().item()}")
        controlled_print(f"Sync Forward - Running phi mean: {running_phi.mean().item()}, std: {running_phi.std().item()}")
        controlled_print(f"Sync Forward - Var mean: {var.mean().item()}, std: {var.std().item()}")
        controlled_print(f"Sync Forward - Output mean: {y.mean().item()}, std: {y.std().item()}")

        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps
        debug = ctx.debug
        current_iter = ctx.current_iter
        warmup_iters = ctx.warmup_iters
        abkw = ctx.abkw

        N, C, H, W = grad_output.size()
        z, var, weight, ema_gz = ctx.saved_tensors

        y = z
        g = grad_output * weight.reshape(1, C, 1, 1)
        g = g * 1

        gz = (g * z).mean(dim=3).mean(dim=2).mean(dim=0)

        approx_grad_g = (g - (1 - abkw) * ema_gz * z)
        ema_gz.add_((approx_grad_g * z).mean(dim=3, keepdim=True).mean(dim=2, keepdim=True).mean(dim=0, keepdim=True))

        if ctx.process_group is not None:
            dist.all_reduce(ema_gz, op=dist.ReduceOp.AVG, group=ctx.process_group)

        gx = 1. / torch.sqrt(var + eps) * approx_grad_g 
        grad_weight = (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0)
        grad_bias = grad_output.sum(dim=3).sum(dim=2).sum(dim=0)

        # ctx.process_group is always set in this Function (possibly to None), so test its value
        is_sync = ctx.process_group is not None
        prefix = "Sync" if is_sync else "Original"

        if is_sync:
            # Redundant under DDP, which already averages parameter gradients across ranks
            process_group = ctx.process_group
            dist.all_reduce(grad_weight, op=dist.ReduceOp.AVG, group=process_group)
            dist.all_reduce(grad_bias, op=dist.ReduceOp.AVG, group=process_group)
        
        controlled_print(f"{prefix} Backward - Grad output mean: {grad_output.mean().item()}, std: {grad_output.std().item()}")
        controlled_print(f"{prefix} Backward - Grad input mean: {gx.mean().item()}, std: {gx.std().item()}")
        controlled_print(f"{prefix} Backward - Grad weight mean: {grad_weight.mean().item()}, std: {grad_weight.std().item()}")
        controlled_print(f"{prefix} Backward - Grad bias mean: {grad_bias.mean().item()}, std: {grad_bias.std().item()}")

        # One gradient per forward input (14); non-differentiable inputs get None
        return gx, grad_weight, grad_bias, None, None, None, None, None, None, None, None, None, None, None

class SyncMaskPowerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, alpha_fwd=0.9, alpha_bkw=0.9,
                 affine=True, warmup_iters=10000, group_num=1, process_group=None):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.process_group = process_group
        self.world_size = dist.get_world_size(process_group) if process_group else 1

        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        
        self.register_buffer('running_phi', torch.ones(1,num_features,1,1))
        self.register_buffer('ema_gz', torch.zeros(1,num_features,1,1))
        self.register_buffer('iters', torch.zeros(1).type(torch.LongTensor))

        # Buffers are not all-reduced here: at construction time they are still CPU tensors,
        # which the NCCL backend cannot reduce, and they are initialised identically on every
        # rank anyway. DDP also broadcasts rank 0's buffers to the other ranks by default.

        self.afwd = alpha_fwd
        self.abkw = alpha_bkw

        self.eps = eps
        self.debug = False
        self.warmup_iters = warmup_iters
        self.gp = GroupScaling1D(group_num=group_num)
        self.group_num = group_num

    def extra_repr(self):
        return '{num_features}, eps={eps}, alpha_fwd={afwd}, alpha_bkw={abkw}, ' \
               'affine={affine}, warmup={warmup_iters}, group_num={group_num}'.format(**self.__dict__)

    def forward(self, input, pad_mask=None, is_encoder=False):
        shaped_input = (len(input.shape) == 2)
        if shaped_input:
            input = input.unsqueeze(0)
        
        is_4d = (input.dim() == 4)
        if is_4d:  # N, C, H, W -> (H*W, N, C)
            N, C, H, W = input.shape
            nchw_shape = (N, C, H, W)
            input = input.permute(2, 3, 0, 1).contiguous().view(H*W, N, C)
        
        T, B, C = input.shape
        input = self.gp(input)

        if pad_mask is None:
            mask_input = input.clone()
        else:
            # Transpose the bn_mask (B x T -> T x B) and keep only the non-padded positions
            bn_mask = ~pad_mask
            bn_mask = bn_mask.transpose(0, 1)
            mask_input = input[bn_mask, :]

        mask_input = mask_input.reshape(-1, self.num_features)

        input = input.permute(1, 2, 0).contiguous()
        input_shape = input.size()
        input = input.reshape(input.size(0), self.num_features, -1)
        input = input.unsqueeze(-1)

        if self.training:
            self.iters.copy_(self.iters + 1) # maybe consider syncing this, but unlikely
            output = SyncPowerFunction.apply(input, self.weight, self.bias, self.running_phi, self.eps,
                        self.afwd, self.abkw, self.ema_gz, self.debug, self.warmup_iters, self.iters, mask_input,
                        self.process_group, self.world_size)
        else:
            N, C, H, W = input.size()
            var = self.running_phi
            output = input / (var + self.eps).sqrt()
            output = self.weight.reshape(1,C,1,1) * output + self.bias.reshape(1,C,1,1)

        output = output.reshape(input_shape)
        output = output.permute(2, 0, 1).contiguous()
        # Undo the NCHW -> (H*W, N, C) reshape from the top so conv inputs get NCHW back
        if is_4d:
            N, C, H, W = nchw_shape
            output = output.view(H, W, N, C).permute(2, 3, 0, 1).contiguous()
        if shaped_input:
            output = output.squeeze(0)

        return output

class TestModel(nn.Module):
    def __init__(self, norm_layer):
        super(TestModel, self).__init__()
        self.conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.norm = norm_layer(64)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.norm(self.conv(x)))

def run_test(local_rank, world_size):
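    # Pin each process to its own GPU before creating the NCCL process group
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", init_method='env://', world_size=world_size, rank=local_rank)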

    # Set same seed for all processes
    torch.manual_seed(42)
    
    # Generate same data for all processes
    batch_size = 32
    torch.manual_seed(42)
    data = torch.randn(batch_size, 3, 64, 64).cuda(local_rank)
    
    # Ensure all processes have the same data
    dist.broadcast(data, src=0)

    if local_rank == 0:
        # Run original MaskPowerNorm on single GPU
        torch.manual_seed(42)
        model_original = TestModel(lambda num_features: MaskPowerNorm(num_features)).cuda(local_rank)
        out_original = model_original(data)
        loss_original = out_original.sum()
        loss_original.backward()
        
        controlled_print("Running original PowerNorm on single GPU")
        controlled_print(f"Original output mean: {out_original.mean().item()}")

    # Run SyncMaskPowerNorm on all GPUs
    torch.manual_seed(42)
    model_sync = TestModel(lambda num_features: SyncMaskPowerNorm(num_features, process_group=dist.group.WORLD)).cuda(local_rank)
    model_sync = DDP(model_sync, device_ids=[local_rank])
    
    controlled_print("Running SyncPowerNorm")
    out_sync = model_sync(data)
    controlled_print(f"Sync output mean on rank {local_rank}: {out_sync.mean().item()}")

    loss_sync = out_sync.sum()
    loss_sync.backward()

    if local_rank == 0:
        for (name_o, param_o), (name_s, param_s) in zip(model_original.named_parameters(), model_sync.named_parameters()):
            if param_o.grad is not None and param_s.grad is not None:
                grad_diff = (param_o.grad - param_s.grad).abs().mean().item()
                controlled_print(f"Gradient difference for {name_o}:")
                controlled_print(f"  Original - mean: {param_o.grad.mean().item()}, std: {param_o.grad.std().item()}")
                controlled_print(f"  Sync     - mean: {param_s.grad.mean().item()}, std: {param_s.grad.std().item()}")
                controlled_print(f"  Absolute difference: {grad_diff}")


    dist.barrier()
    dist.destroy_process_group()

def main():
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    controlled_print(f"Running on rank {local_rank} of {world_size}")
    run_test(local_rank, world_size)

if __name__ == "__main__":
    main()
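On the question raised at the start of the thread about the missing else: as posted, the running_phi update applies the cumulative-average step during warmup and then the EMA step on every iteration, including warmup ones. A scalar sketch in plain Python makes the effect visible. The helper name update_running_phi and the per-iteration values are hypothetical; the two update lines mirror the ones in PowerFunction.forward:

```python
def update_running_phi(running_phi, var, current_iter, warmup_iters, afwd):
    """Hypothetical scalar mirror of the two running_phi.copy_ lines in PowerFunction.forward."""
    if current_iter < warmup_iters:
        # cumulative moving average of var over iterations 1..current_iter
        running_phi = running_phi * (current_iter - 1) / current_iter + var / current_iter
    # no else: this EMA step also runs during warmup
    running_phi = afwd * running_phi + (1 - afwd) * var
    return running_phi

vars_per_iter = [2.0, 4.0, 6.0]  # hypothetical per-iteration batch statistics
phi_as_written = 1.0             # running_phi buffer starts at ones
phi_cma_only = 1.0
# iters is incremented before each call, so the first forward sees current_iter == 1
for it, v in enumerate(vars_per_iter, start=1):
    phi_as_written = update_running_phi(phi_as_written, v, it, warmup_iters=10, afwd=0.9)
    phi_cma_only = phi_cma_only * (it - 1) / it + v / it

print(phi_cma_only)    # 4.0, the plain mean of vars_per_iter
print(phi_as_written)  # about 4.26, pulled toward the most recent statistic by the EMA step
```

With a pure cumulative average, running_phi would equal the mean of all statistics seen during warmup. The extra EMA step weights recent iterations more heavily.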

Ice-Citron, Sep 01 '24 11:09

Only realised just now that this normalisation layer is meant for ViTs instead of language transformers

But in the paper, they were using this norm layer for machine translation and language modeling tasks.
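In the padded-sequence setting of machine translation and language modeling, the relevant part of MaskPowerNorm.forward is how pad_mask (B x T, True at padding) selects the tokens that enter the second-moment statistic. Here is a NumPy sketch of that selection. The sizes and mask values are hypothetical:

```python
import numpy as np

T, B, C = 4, 2, 3                     # hypothetical sizes: time steps, batch, channels
x = np.arange(T * B * C, dtype=float).reshape(T, B, C)
pad_mask = np.array([[False, False, False, True],    # sentence 0: 3 real tokens
                     [False, False, True,  True]])   # sentence 1: 2 real tokens (B x T, True = pad)

bn_mask = ~pad_mask.T                 # T x B, True = real token
mask_x = x[bn_mask, :]                # (num_real_tokens, C); padded positions are dropped
var = (mask_x * mask_x).mean(axis=0)  # per-channel second moment over real tokens only

# Same statistic computed with an explicit loop over the non-padded positions.
real = [x[t, b] for t in range(T) for b in range(B) if not pad_mask[b, t]]
var_loop = np.mean(np.square(real), axis=0)
print(mask_x.shape)  # (5, 3)
```

Padding tokens never contribute to var, so the statistic is not diluted by batches with many short sentences.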

Here's the final code instead

You didn't paste the full final code, only the forward of the nn.Module. Can you also post the autograd.Function?

lumliolum, Sep 01 '24 12:09

@lumliolum Ah, sorry, there you go. Please try my full code instead and check it. I recommend Tensordock if you want to get started; something like 4x Nvidia A4000 already works and costs 0.5 USD per hour.

But in the paper, they were using this norm layer for machine translation and language modeling tasks.

Ah, I see. I'm just trying to make sure I'm able to convert my .pth transformer model to Hugging Face; after that I'll start training my own models on a node of 8x H100 soon.

Ice-Citron, Sep 01 '24 12:09

@lumliolum Hi, just wanted to check in. Any issues since then?

Ice-Citron, Sep 07 '24 12:09