A few questions regarding fairseq/modules/norms/mask_powernorm.py
Hi, first of all thank you for your work. I've been spending some time trying to understand what is happening in this script, fairseq/modules/norms/mask_powernorm.py, but I've been having some trouble. Could you please answer these questions?
- Was 'GroupScaling1D' (starting at line 17) specific to the data or model architecture used for the experiments, rather than part of the general PowerNorm method? Based on my understanding, the input is shaped (T tokens or instances, B batches, C channels), and the module acts like a modified layer norm where each value is divided by the root mean square computed over its group of 4 channels, per batch element and per token. I don't believe this was mentioned in the paper (a small sketch of my reading is included below).
- On these few lines in the forward function of PowerFunction:
if current_iter < warmup_iters: running_phi.copy_(running_phi * (current_iter-1)/current_iter + var.mean(dim=0, keepdim=True)/current_iter)
running_phi.copy_(afwd*running_phi + (1-afwd)*var.mean(dim=0, keepdim=True))
Since 'var' has shape (1, C, 1, 1), var.mean(dim=0, keepdim=True) is the same tensor as 'var'. Was this intentional, or is it an artifact of an earlier version of the code? Also, did you mean to put an else before 'running_phi.copy_(afwd*running_phi + (1-afwd)*var.mean(dim=0, keepdim=True))'? As written it always runs and overwrites the warmup update. (A quick check of both points is in the second snippet below.)
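Here is a minimal sketch of what I understand GroupScaling1D to compute; the shapes and group_num=4 are just illustrative, and I use broadcasting instead of repeat_interleave, which should be equivalent:

import torch

# GroupScaling1D, as I read it: divide each value by the root mean square of its group of
# C // group_num channels, computed separately per token and per batch element.
T, B, C, group_num, eps = 5, 2, 8, 4, 1e-5
x = torch.randn(T, B, C)

gn = x.reshape(T, B, group_num, C // group_num)
moment2 = (gn * gn).mean(dim=3, keepdim=True)          # per-group second moment, (T, B, group_num, 1)
out = (gn / torch.sqrt(moment2 + eps)).reshape(T, B, C)

And a quick check of the two points about the running_phi update (C is arbitrary here):

import torch

C = 8
var = torch.rand(1, C, 1, 1)
# the mean over a singleton dimension is a no-op, so this is just `var` again
assert torch.equal(var.mean(dim=0, keepdim=True), var)

# As written, the second running_phi.copy_ always executes, so during warmup the cumulative
# average is immediately overwritten by the exponential moving average; with an `else`,
# exactly one of the two updates would apply per iteration.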
Thank you, I'd very much appreciate your time.
Hello,
I have some more questions, continuing on what was mentioned before.
In the NormSelect function https://github.com/sIncerass/powernorm/blob/9ea6226a3203d5d6fcee07a5c6dec38ec6bc5e9f/fairseq/modules/norm_select.py#L12-L19, batch norm uses MaskSyncBatchNorm, a synchronized version of batch norm needed for multi-GPU training, but for power norm I didn't see any SyncPowerNorm. Is that because PowerNorm doesn't need a synchronized version? As I understand it, a sync version is needed whenever batch statistics are used (which is why there is no sync layer norm).
Also, the appendix of the paper mentions that a synchronized version is used for "PN-V". If possible, could you release that part of the code as well?
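For context, my mental model of why a norm based on batch statistics needs a synchronized variant is roughly the sketch below; this is not code from this repo, and the per-channel quadratic statistic and the process_group argument are just my assumptions:

import torch
import torch.distributed as dist

def synced_power_statistic(x, process_group=None):
    # x: (N, C) activations local to this rank; psi is the per-channel second moment
    psi = (x * x).mean(dim=0)
    if dist.is_initialized():
        # without this all_reduce, every rank would normalize with a statistic
        # computed from its own local mini-batch only
        dist.all_reduce(psi, op=dist.ReduceOp.SUM, group=process_group)
        psi /= dist.get_world_size(process_group)
    return psi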
The discussions around GroupScaling are given here: #9, #8
Thx for the clarification. I'm not the original author, but I'm trying to implement this myself right now; if I'm successful I will share it here.
@lumliolum I have come up with this for now; I also checked it against the original mask_powernorm.py as much as I could. Can you help me check it too?
I ran the command "torchrun --standalone --nproc_per_node=4 test.py" on a setup with PyTorch installed and 4 GPUs (V100s in my case). I will verify more thoroughly that the code works soon. But, with the help of Claude 3.5 and GPT-4, I managed to come up with this makeshift solution. What I did was:
- Implemented a sync power norm version, and pasted the original version into the same file as well
- Then I passed the same data into both and initialised the two models identically (calling torch.manual_seed(42) before generating the data with randn, and again before initialising each random NN model: one using SyncPowerNorm, which is run by all 4 GPUs, and one using the original MaskPowerNorm)
- Then I ran the file, with controlled_print placed everywhere to monitor each forward and backward pass, and calculated the mean absolute difference between the std and mean of the gradients of the two different NNs (one using the original PowerNorm, the other the syncing version I made)
The result: the mean absolute difference is 0, and the two networks, one using the original and the other using the SyncPowerNorm I implemented, appear to produce exactly the same results, as seen in the image below:
Success? I'm not sure yet. I will try and double check and get back to you.
# torchrun --standalone --nproc_per_node=4 test.py
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import os
# Global variable to control printing
PRINT_ALL_RANKS = True
def controlled_print(message):
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
if PRINT_ALL_RANKS or local_rank == 0:
print(f"Rank {local_rank}: {message}")
class GroupScaling1D(nn.Module):
def __init__(self, eps=1e-5, group_num=4):
super(GroupScaling1D, self).__init__()
self.eps = eps
self.group_num = group_num
def extra_repr(self):
return f'eps={self.eps}, group={self.group_num}'
def forward(self, input):
T, B, C = input.shape[0], input.shape[1], input.shape[2]
Cg = C // self.group_num
gn_input = input.contiguous().reshape(T, B, self.group_num, Cg)
moment2 = torch.repeat_interleave(torch.mean(gn_input * gn_input, dim=3, keepdim=True),
repeats=Cg, dim=-1).contiguous().reshape(T, B, C)
return input / torch.sqrt(moment2 + self.eps)
class PowerFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, bias, running_phi, eps, afwd, abkw, ema_gz,
debug, warmup_iters, current_iter, mask_x):
# Original PowerFunction forward code here
ctx.eps = eps
ctx.debug = debug
current_iter = current_iter.item()
ctx.current_iter = current_iter
ctx.warmup_iters = warmup_iters
ctx.abkw = abkw
rmax = 1
N, C, H, W = x.size()
x2 = (mask_x * mask_x).mean(dim=0)
var = x2.reshape(1, C, 1, 1)
if current_iter <= warmup_iters:
z = x /(var + eps).sqrt()
else:
z = x /(running_phi + eps).sqrt()
y = z
ctx.save_for_backward(z, var, weight, ema_gz)
if current_iter < warmup_iters:
running_phi.copy_(running_phi * (current_iter-1)/current_iter + var.mean(dim=0, keepdim=True)/current_iter)
running_phi.copy_(afwd*running_phi + (1-afwd)*var.mean(dim=0, keepdim=True))
y = weight.reshape(1,C,1,1) * y + bias.reshape(1,C,1,1)
controlled_print(f"Original Forward - Input mean: {x.mean().item()}, std: {x.std().item()}")
controlled_print(f"Original Forward - Weight mean: {weight.mean().item()}, std: {weight.std().item()}")
controlled_print(f"Original Forward - Bias mean: {bias.mean().item()}, std: {bias.std().item()}")
controlled_print(f"Original Forward - Running phi mean: {running_phi.mean().item()}, std: {running_phi.std().item()}")
controlled_print(f"Original Forward - Var mean: {var.mean().item()}, std: {var.std().item()}")
controlled_print(f"Original Forward - Output mean: {y.mean().item()}, std: {y.std().item()}")
return y
@staticmethod
def backward(ctx, grad_output):
eps = ctx.eps
debug = ctx.debug
current_iter = ctx.current_iter
warmup_iters = ctx.warmup_iters
abkw = ctx.abkw
N, C, H, W = grad_output.size()
z, var, weight, ema_gz = ctx.saved_tensors
y = z
g = grad_output * weight.reshape(1, C, 1, 1)
g = g * 1
gz = (g * z).mean(dim=3).mean(dim=2).mean(dim=0)
approx_grad_g = (g - (1 - abkw) * ema_gz * z)
ema_gz.add_((approx_grad_g * z).mean(dim=3, keepdim=True).mean(dim=2, keepdim=True).mean(dim=0, keepdim=True))
gx = 1. / torch.sqrt(var + eps) * approx_grad_g
grad_weight = (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0)
grad_bias = grad_output.sum(dim=3).sum(dim=2).sum(dim=0)
is_sync = hasattr(ctx, 'process_group')
prefix = "Sync" if is_sync else "Original"
controlled_print(f"{prefix} Backward - Grad output mean: {grad_output.mean().item()}, std: {grad_output.std().item()}")
controlled_print(f"{prefix} Backward - Grad input mean: {gx.mean().item()}, std: {gx.std().item()}")
controlled_print(f"{prefix} Backward - Grad weight mean: {grad_weight.mean().item()}, std: {grad_weight.std().item()}")
controlled_print(f"{prefix} Backward - Grad bias mean: {grad_bias.mean().item()}, std: {grad_bias.std().item()}")
return gx, grad_weight, grad_bias, None, None, None, None, None, None, None, None, None, None, None, None
class MaskPowerNorm(nn.Module):
def __init__(self, num_features, eps=1e-5, alpha_fwd=0.9, alpha_bkw=0.9,
affine=True, warmup_iters=10000, group_num=1):
super().__init__()
self.num_features = num_features
self.eps = eps
self.affine = affine
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.register_buffer('running_phi', torch.ones(1,num_features,1,1))
self.register_buffer('ema_gz', torch.zeros(1,num_features,1,1))
self.register_buffer('iters', torch.zeros(1).type(torch.LongTensor))
self.afwd = alpha_fwd
self.abkw = alpha_bkw
self.debug = False
self.warmup_iters = warmup_iters
self.gp = GroupScaling1D(group_num=group_num)
self.group_num = group_num
def extra_repr(self):
return '{num_features}, eps={eps}, alpha_fwd={afwd}, alpha_bkw={abkw}, ' \
'affine={affine}, warmup={warmup_iters}, group_num={group_num}'.format(**self.__dict__)
def forward(self, input, pad_mask=None, is_encoder=False):
shaped_input = (len(input.shape) == 2)
if shaped_input:
input = input.unsqueeze(0)
if input.dim() == 4: # N, C, H, W
N, C, H, W = input.shape
input = input.permute(2, 3, 0, 1).contiguous().view(H*W, N, C)
T, B, C = input.shape
input = self.gp(input)
# construct the mask_input, size to be (BxL) x C: L is the real length here
if pad_mask is None:
mask_input = input.clone()
else:
# Transpose the bn_mask (B x T -> T x B)
bn_mask = ~pad_mask
bn_mask = bn_mask.transpose(0, 1)
if pad_mask is not None:
pad_size = (~bn_mask).sum()
mask_input = input[bn_mask, :]
else:
mask_input = input.clone()
mask_input = mask_input.reshape(-1, self.num_features)
input = input.permute(1, 2, 0).contiguous()
input_shape = input.size()
input = input.reshape(input.size(0), self.num_features, -1)
input = input.unsqueeze(-1)
if self.training:
self.iters.copy_(self.iters + 1)
output = PowerFunction.apply(input, self.weight, self.bias, self.running_phi, self.eps, \
self.afwd, self.abkw, self.ema_gz, self.debug, self.warmup_iters, self.iters, mask_input)
else:
N, C, H, W = input.size()
var = self.running_phi
output = input / (var + self.eps).sqrt()
output = self.weight.reshape(1,C,1,1) * output + self.bias.reshape(1,C,1,1)
output = output.reshape(input_shape)
output = output.permute(2, 0, 1).contiguous()
# Reshape it.
if shaped_input:
output = output.squeeze(0)
return output
class SyncPowerFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, bias, running_phi, eps, afwd, abkw, ema_gz,
debug, warmup_iters, current_iter, mask_x, process_group, world_size):
ctx.eps = eps
ctx.debug = debug
current_iter = current_iter.item()
ctx.current_iter = current_iter
ctx.warmup_iters = warmup_iters
ctx.abkw = abkw
ctx.process_group = process_group
ctx.world_size = world_size
N, C, H, W = x.size()
x2 = (mask_x * mask_x).mean(dim=0)
var = x2.reshape(1, C, 1, 1)
# Synchronize var across GPUs
if process_group is not None:
dist.all_reduce(var, op=dist.ReduceOp.SUM, group=process_group)
var /= world_size
if current_iter <= warmup_iters:
z = x / (var + eps).sqrt()
else:
z = x / (running_phi + eps).sqrt()
y = z
ctx.save_for_backward(z, var, weight, ema_gz)
if current_iter < warmup_iters:
running_phi.copy_(running_phi * (current_iter-1)/current_iter + var.mean(dim=0, keepdim=True)/current_iter)
running_phi.copy_(afwd*running_phi + (1-afwd)*var.mean(dim=0, keepdim=True))
y = weight.reshape(1,C,1,1) * y + bias.reshape(1,C,1,1)
controlled_print(f"Sync Forward - Input mean: {x.mean().item()}, std: {x.std().item()}")
controlled_print(f"Sync Forward - Weight mean: {weight.mean().item()}, std: {weight.std().item()}")
controlled_print(f"Sync Forward - Bias mean: {bias.mean().item()}, std: {bias.std().item()}")
controlled_print(f"Sync Forward - Running phi mean: {running_phi.mean().item()}, std: {running_phi.std().item()}")
controlled_print(f"Sync Forward - Var mean: {var.mean().item()}, std: {var.std().item()}")
controlled_print(f"Sync Forward - Output mean: {y.mean().item()}, std: {y.std().item()}")
return y
@staticmethod
def backward(ctx, grad_output):
eps = ctx.eps
debug = ctx.debug
current_iter = ctx.current_iter
warmup_iters = ctx.warmup_iters
abkw = ctx.abkw
N, C, H, W = grad_output.size()
z, var, weight, ema_gz = ctx.saved_tensors
y = z
g = grad_output * weight.reshape(1, C, 1, 1)
g = g * 1
gz = (g * z).mean(dim=3).mean(dim=2).mean(dim=0)
approx_grad_g = (g - (1 - abkw) * ema_gz * z)
ema_gz.add_((approx_grad_g * z).mean(dim=3, keepdim=True).mean(dim=2, keepdim=True).mean(dim=0, keepdim=True))
gx = 1. / torch.sqrt(var + eps) * approx_grad_g
grad_weight = (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0)
grad_bias = grad_output.sum(dim=3).sum(dim=2).sum(dim=0)
is_sync = hasattr(ctx, 'process_group')
prefix = "Sync" if is_sync else "Original"
if is_sync:
process_group = ctx.process_group
world_size = ctx.world_size
dist.all_reduce(grad_weight, op=dist.ReduceOp.SUM, group=process_group)
dist.all_reduce(grad_bias, op=dist.ReduceOp.SUM, group=process_group)
grad_weight /= world_size
grad_bias /= world_size
controlled_print(f"{prefix} Backward - Grad output mean: {grad_output.mean().item()}, std: {grad_output.std().item()}")
controlled_print(f"{prefix} Backward - Grad input mean: {gx.mean().item()}, std: {gx.std().item()}")
controlled_print(f"{prefix} Backward - Grad weight mean: {grad_weight.mean().item()}, std: {grad_weight.std().item()}")
controlled_print(f"{prefix} Backward - Grad bias mean: {grad_bias.mean().item()}, std: {grad_bias.std().item()}")
return gx, grad_weight, grad_bias, None, None, None, None, None, None, None, None, None, None, None, None
class SyncMaskPowerNorm(nn.Module):
def __init__(self, num_features, eps=1e-5, alpha_fwd=0.9, alpha_bkw=0.9,
affine=True, warmup_iters=10000, group_num=1, process_group=None):
super().__init__()
self.num_features = num_features
self.eps = eps
self.affine = affine
self.process_group = process_group
self.world_size = dist.get_world_size(process_group) if process_group else 1
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.register_buffer('running_phi', torch.ones(1,num_features,1,1))
self.register_buffer('ema_gz', torch.zeros(1,num_features,1,1))
self.register_buffer('iters', torch.zeros(1).type(torch.LongTensor))
self.afwd = alpha_fwd
self.abkw = alpha_bkw
self.eps = eps
self.debug = False
self.warmup_iters = warmup_iters
self.gp = GroupScaling1D(group_num=group_num)
self.group_num = group_num
def extra_repr(self):
return '{num_features}, eps={eps}, alpha_fwd={afwd}, alpha_bkw={abkw}, ' \
'affine={affine}, warmup={warmup_iters}, group_num={group_num}'.format(**self.__dict__)
def forward(self, input, pad_mask=None, is_encoder=False):
shaped_input = (len(input.shape) == 2)
if shaped_input:
input = input.unsqueeze(0)
if input.dim() == 4: # N, C, H, W
N, C, H, W = input.shape
input = input.permute(2, 3, 0, 1).contiguous().view(H*W, N, C)
T, B, C = input.shape
input = self.gp(input)
if pad_mask is None:
mask_input = input.clone()
else:
bn_mask = ~pad_mask
bn_mask = bn_mask.transpose(0, 1)
if pad_mask is not None:
pad_size = (~bn_mask).sum()
mask_input = input[bn_mask, :]
else:
mask_input = input.clone()
mask_input = mask_input.reshape(-1, self.num_features)
input = input.permute(1, 2, 0).contiguous()
input_shape = input.size()
input = input.reshape(input.size(0), self.num_features, -1)
input = input.unsqueeze(-1)
if self.training:
self.iters.copy_(self.iters + 1)
output = SyncPowerFunction.apply(input, self.weight, self.bias, self.running_phi, self.eps,
self.afwd, self.abkw, self.ema_gz, self.debug, self.warmup_iters, self.iters, mask_input,
self.process_group, self.world_size)
else:
N, C, H, W = input.size()
var = self.running_phi
output = input / (var + self.eps).sqrt()
output = self.weight.reshape(1,C,1,1) * output + self.bias.reshape(1,C,1,1)
output = output.reshape(input_shape)
output = output.permute(2, 0, 1).contiguous()
if shaped_input:
output = output.squeeze(0)
return output
class TestModel(nn.Module):
def __init__(self, norm_layer):
super(TestModel, self).__init__()
self.conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
self.norm = norm_layer(64)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
return self.relu(self.norm(self.conv(x)))
def run_test(local_rank, world_size):
dist.init_process_group(backend="nccl", init_method='env://', world_size=world_size, rank=local_rank)
# Set same seed for all processes
torch.manual_seed(42)
# Generate same data for all processes
batch_size = 32
torch.manual_seed(42)
data = torch.randn(batch_size, 3, 64, 64).cuda(local_rank)
# Ensure all processes have the same data
dist.broadcast(data, src=0)
if local_rank == 0:
# Run original MaskPowerNorm on single GPU
torch.manual_seed(42)
model_original = TestModel(lambda num_features: MaskPowerNorm(num_features)).cuda(local_rank)
out_original = model_original(data)
loss_original = out_original.sum()
loss_original.backward()
controlled_print("Running original PowerNorm on single GPU")
controlled_print(f"Original output mean: {out_original.mean().item()}")
# Run SyncMaskPowerNorm on all GPUs
torch.manual_seed(42)
model_sync = TestModel(lambda num_features: SyncMaskPowerNorm(num_features, process_group=dist.group.WORLD)).cuda(local_rank)
model_sync = DDP(model_sync, device_ids=[local_rank])
controlled_print("Running SyncPowerNorm")
out_sync = model_sync(data)
controlled_print(f"Sync output mean on rank {local_rank}: {out_sync.mean().item()}")
loss_sync = out_sync.sum()
loss_sync.backward()
if local_rank == 0:
for (name_o, param_o), (name_s, param_s) in zip(model_original.named_parameters(), model_sync.named_parameters()):
if param_o.grad is not None and param_s.grad is not None:
grad_diff = (param_o.grad - param_s.grad).abs().mean().item()
controlled_print(f"Gradient difference for {name_o}:")
controlled_print(f" Original - mean: {param_o.grad.mean().item()}, std: {param_o.grad.std().item()}")
controlled_print(f" Sync - mean: {param_s.grad.mean().item()}, std: {param_s.grad.std().item()}")
controlled_print(f" Absolute difference: {grad_diff}")
dist.barrier()
dist.destroy_process_group()
def main():
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
controlled_print(f"Running on rank {local_rank} of {world_size}")
run_test(local_rank, world_size)
if __name__ == "__main__":
main()
Hey @Ice-Citron
Thanks for sharing the code for SyncPowerNorm. Currently I don't have access to the server (with multiple GPUs) to run your code. I will get it back in a few weeks, I guess (if possible, I will update you then).
I don't have much knowledge of how distributed systems and code work (apologies in advance if I write something wrong), but anyway I had a look. Some queries I have:
- In the forward of SyncPowerNorm, I saw that you are using all_reduce on the variance, and since running_phi uses only var, it will be the same across all devices. In the backward you are doing all_reduce on grad_weight and grad_bias but not on gx. Can I know why?
- Continuation of the above: for running_ema, which is the running statistic in the backward, there is no all_reduce? At each point running_ema will be different across devices. I feel that for running_ema we should also use all_reduce (see the toy illustration below).
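A toy illustration of what I mean (no real GPUs involved, and this is a generic EMA-style update, not the exact ema_gz update in the code):

abkw = 0.9
ema = {0: 0.0, 1: 0.0}            # pretend buffers on rank 0 and rank 1
local_stat = {0: 1.0, 1: 3.0}     # pretend per-rank batch statistics

for _ in range(100):
    for rank in (0, 1):
        ema[rank] = abkw * ema[rank] + (1 - abkw) * local_stat[rank]

print(ema[0], ema[1])             # ~1.0 vs ~3.0: the buffers no longer agree across ranks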
@lumliolum Hi. Sorry that I have been late on this. I have been very busy with other school work and projects, but I have two hours now and will try to answer your questions by verifying and figuring out what's going on with my code, to make sure that it works.
@lumliolum Give me a bit longer. I've just managed to understand how the maths etc. works. Only realised just now that this normalisation layer is meant for ViTs instead of language transformers. lol
@lumliolum Actually yeah, good point. When I rushed out the code initially I was just trying to get something running as fast as possible and didn't realise that. I will look at whether all_reduce is necessary for running_ema.
Meanwhile, as for the reason why gx isn't synchronised, this is what GPT-4 said, which makes a lot of sense:
The reason gx (the gradient of the loss with respect to the input x) does not require synchronization through all_reduce or similar operations in a distributed training context like DDP (Distributed Data Parallel) primarily relates to how and where it is used in neural network training.
Understanding the Use of gx in Backpropagation
- Local Relevance:
  - gx represents the gradient of the network's loss with respect to its inputs at a particular layer. This gradient is used locally by each GPU to compute gradients for the weights and biases that are directly connected to these inputs within the same layer or previous layers.
  - Since each GPU processes a different subset of the data (mini-batch), gx computed on one GPU is specifically relevant to the forward pass computations and the immediate backpropagation calculations on that same GPU. There is no need to share or synchronize gx across GPUs because it does not directly influence the parameters being updated on other GPUs.
- Parameter Updates vs. Input Gradients:
  - In DDP, the primary concern is to synchronize the gradients of the parameters (weights and biases) across all GPUs before they are updated. This synchronization ensures that each GPU updates its model parameters based on the complete information from the entire distributed dataset, thus maintaining consistency across all replicas of the model.
  - gx, on the other hand, is used to propagate error gradients back through the network to update parameters local to each layer. Each instance of backpropagation uses gx from the subsequent layer to calculate the gradients for parameters in the current layer, but these gradients do not need to be shared across GPUs because each GPU independently handles a portion of the data.
- No Global Impact:
  - gx affects only the local computation within a GPU and is used to calculate gradients of parameters that will subsequently be synchronized. Since the parameters are the elements that ultimately need to be consistent across all GPUs, only their gradients require synchronization.
  - If gx were part of a computation that influenced model parameters globally (across all GPUs), then it would indeed need to be synchronized. However, in standard neural network operations, this is not the case.
Practical Implications
In a practical scenario, consider a multi-layer neural network training across multiple GPUs:
- Each GPU computes the forward pass and backward pass locally for the batch of data it processes.
- During the backward pass, gx for each layer is computed based on the output gradient from the subsequent layer (grad_output), which is also local to the GPU.
- The calculated gradients of the weights and biases (which impact the model globally) are then synchronized across GPUs using operations like all_reduce to ensure every GPU has the same updated view before the next iteration of training.
Conclusion
gx not requiring synchronization across GPUs allows each part of the model to be trained as efficiently and quickly as possible on its batch of data, relying on the synchronization of parameter gradients to ensure overall model consistency. This approach optimizes network bandwidth and computation by avoiding unnecessary data transfers and maintains the independence of local computations, which is a key advantage in distributed training setups.
Question: is it wrong if gx is synchronised?
Synchronizing gx (the gradient of the loss with respect to the input at a particular layer) across GPUs in a distributed training setup isn't typically necessary, but doing so wouldn't be inherently "wrong"; rather, it would be inefficient and largely redundant. Let's explore why this is generally avoided and what implications it might have if done:
1. Redundancy and Efficiency
- Redundancy: In most training scenarios, gx is used locally within the same GPU to compute gradients of parameters (weights and biases) that are earlier in the network relative to the current layer. Since gx is only used to compute updates that are local to specific layers, sharing it across GPUs does not contribute additional beneficial information that would influence the global model training.
- Efficiency: Transferring gx across GPUs consumes bandwidth and computational resources, which could otherwise be used for synchronizing parameters that directly affect the consistency and performance of the model across all GPUs. The overhead introduced by unnecessarily synchronizing gx could slow down the training process without providing tangible benefits.
2. Practical Impact of Synchronization
- No Benefit: Synchronizing gx across GPUs does not benefit the training process because each GPU is already handling its portion of the batch data. The local gx on each GPU is sufficient for the backpropagation through the layers handled by that GPU.
- Potential Confusion: Synchronizing gx might introduce complexity into the training logic without clear justification. It could potentially lead to confusion about the role and usage of gradients in distributed training setups, where the focus should ideally be on synchronizing parameters that affect the model globally.
3. Conceptual Clarity
- Layer-Specific Calculations: Backpropagation is inherently a layer-specific calculation where each layer computes its output gradients based on inputs from the previous layer (or the subsequent layer in the case of backpropagation). This sequence ensures that the error signal is propagated correctly backward through the network, adjusting each layer's parameters based on the local gradient computations.
- Isolation of Layers: Each layer's parameters are updated based on gradients computed from its outputs. Synchronizing gx would imply an unnecessary linkage between the gradient flows of different parts of the network across different GPUs, which is not how neural network training is typically conceptualized or implemented.
4. Theoretical Considerations
- Local vs. Global Gradients: While parameter gradients need to reflect global data characteristics (hence the synchronization), gx reflects how local computations at each layer should adjust to minimize the loss. Its value inherently depends on local data and layer-specific dynamics, which are contextually isolated from similar computations on other GPUs.
Conclusion
While not "wrong," synchronizing gx is not standard practice and does not align with the principles of efficient, effective distributed training. Keeping gx local optimizes computational resources and aligns with the foundational concepts of neural network training, where each part of the network learns based on the data it processes directly. The key to efficient distributed training is to minimize overhead while ensuring global consistency where it impacts the model, primarily at the level of model parameters, not intermediate gradients like gx.
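To make that distinction concrete, here is a generic sketch of what gradient synchronization amounts to (not specific to PowerNorm, and DDP already does this reduction for you): only the parameter gradients are averaged across ranks, while the input gradient gx returned by backward is consumed locally by the preceding layer on the same rank.

import torch
import torch.distributed as dist

def average_parameter_grads(module: torch.nn.Module, process_group=None):
    # Average each parameter's .grad across ranks; nothing equivalent is needed for gx.
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size(process_group)
    for p in module.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=process_group)
            p.grad /= world_size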
@lumliolum I'm looking into running_ema now. Have a feeling this needs to be synchronised, but let's see.
@lumliolum Yep, you're correct. Here's the final code instead. Please do let me know if there are any remaining logic errors. I have already tried to double-check with Claude and GPT-4 and there doesn't seem to be any, but let me know if anything is wrong. Thx
# torchrun --standalone --nproc_per_node=4 test.py
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import os
# Global variable to control printing
PRINT_ALL_RANKS = True
def controlled_print(message):
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
if PRINT_ALL_RANKS or local_rank == 0:
print(f"Rank {local_rank}: {message}")
class GroupScaling1D(nn.Module):
def __init__(self, eps=1e-5, group_num=4):
super(GroupScaling1D, self).__init__()
self.eps = eps
self.group_num = group_num
def extra_repr(self):
return f'eps={self.eps}, group={self.group_num}'
def forward(self, input):
T, B, C = input.shape[0], input.shape[1], input.shape[2]
Cg = C // self.group_num
gn_input = input.contiguous().reshape(T, B, self.group_num, Cg)
moment2 = torch.repeat_interleave(torch.mean(gn_input * gn_input, dim=3, keepdim=True),
repeats=Cg, dim=-1).contiguous().reshape(T, B, C)
return input / torch.sqrt(moment2 + self.eps)
class PowerFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, bias, running_phi, eps, afwd, abkw, ema_gz,
debug, warmup_iters, current_iter, mask_x):
# Original PowerFunction forward code here
ctx.eps = eps
ctx.debug = debug
current_iter = current_iter.item()
ctx.current_iter = current_iter
ctx.warmup_iters = warmup_iters
ctx.abkw = abkw
rmax = 1
N, C, H, W = x.size()
x2 = (mask_x * mask_x).mean(dim=0)
var = x2.reshape(1, C, 1, 1)
if current_iter <= warmup_iters:
z = x /(var + eps).sqrt() # dividing by sqrt(varience + eps), which is same as dividing by standard deviation
else:
z = x /(running_phi + eps).sqrt() # same thing as above, but using running stats instead
y = z
ctx.save_for_backward(z, var, weight, ema_gz)
if current_iter < warmup_iters:
running_phi.copy_(running_phi * (current_iter-1)/current_iter + var.mean(dim=0, keepdim=True)/current_iter) # cumulative moving average
running_phi.copy_(afwd*running_phi + (1-afwd)*var.mean(dim=0, keepdim=True)) # exponential moving average
y = weight.reshape(1,C,1,1) * y + bias.reshape(1,C,1,1)
controlled_print(f"Original Forward - Input mean: {x.mean().item()}, std: {x.std().item()}")
controlled_print(f"Original Forward - Weight mean: {weight.mean().item()}, std: {weight.std().item()}")
controlled_print(f"Original Forward - Bias mean: {bias.mean().item()}, std: {bias.std().item()}")
controlled_print(f"Original Forward - Running phi mean: {running_phi.mean().item()}, std: {running_phi.std().item()}")
controlled_print(f"Original Forward - Var mean: {var.mean().item()}, std: {var.std().item()}")
controlled_print(f"Original Forward - Output mean: {y.mean().item()}, std: {y.std().item()}")
return y
@staticmethod
def backward(ctx, grad_output):
eps = ctx.eps
debug = ctx.debug
current_iter = ctx.current_iter
warmup_iters = ctx.warmup_iters
abkw = ctx.abkw
N, C, H, W = grad_output.size()
z, var, weight, ema_gz = ctx.saved_tensors
y = z
g = grad_output * weight.reshape(1, C, 1, 1)
g = g * 1
gz = (g * z).mean(dim=3).mean(dim=2).mean(dim=0)
approx_grad_g = (g - (1 - abkw) * ema_gz * z) # approx function seems to just be using CTX stored tensors
ema_gz.add_((approx_grad_g * z).mean(dim=3, keepdim=True).mean(dim=2, keepdim=True).mean(dim=0, keepdim=True))
gx = 1. / torch.sqrt(var + eps) * approx_grad_g # REFER TO NOTES REGARDING BACKPROP DERIVATIVE EQUATION
grad_weight = (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0)
grad_bias = grad_output.sum(dim=3).sum(dim=2).sum(dim=0)
is_sync = hasattr(ctx, 'process_group')
prefix = "Sync" if is_sync else "Original"
controlled_print(f"{prefix} Backward - Grad output mean: {grad_output.mean().item()}, std: {grad_output.std().item()}")
controlled_print(f"{prefix} Backward - Grad input mean: {gx.mean().item()}, std: {gx.std().item()}")
controlled_print(f"{prefix} Backward - Grad weight mean: {grad_weight.mean().item()}, std: {grad_weight.std().item()}")
controlled_print(f"{prefix} Backward - Grad bias mean: {grad_bias.mean().item()}, std: {grad_bias.std().item()}")
return gx, grad_weight, grad_bias, None, None, None, None, None, None, None, None, None, None, None, None
class MaskPowerNorm(nn.Module):
def __init__(self, num_features, eps=1e-5, alpha_fwd=0.9, alpha_bkw=0.9,
affine=True, warmup_iters=10000, group_num=1):
super().__init__()
self.num_features = num_features
self.eps = eps
self.affine = affine
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.register_buffer('running_phi', torch.ones(1,num_features,1,1))
self.register_buffer('ema_gz', torch.zeros(1,num_features,1,1))
self.register_buffer('iters', torch.zeros(1).type(torch.LongTensor))
self.afwd = alpha_fwd
self.abkw = alpha_bkw
self.debug = False
self.warmup_iters = warmup_iters
self.gp = GroupScaling1D(group_num=group_num)
self.group_num = group_num
def extra_repr(self):
return '{num_features}, eps={eps}, alpha_fwd={afwd}, alpha_bkw={abkw}, ' \
'affine={affine}, warmup={warmup_iters}, group_num={group_num}'.format(**self.__dict__)
def forward(self, input, pad_mask=None, is_encoder=False):
shaped_input = (len(input.shape) == 2)
if shaped_input:
input = input.unsqueeze(0)
if input.dim() == 4: # N, C, H, W
N, C, H, W = input.shape
input = input.permute(2, 3, 0, 1).contiguous().view(H*W, N, C)
T, B, C = input.shape
input = self.gp(input)
# construct the mask_input, size to be (BxL) x C: L is the real length here
if pad_mask is None:
mask_input = input.clone()
else:
# Transpose the bn_mask (B x T -> T x B)
bn_mask = ~pad_mask
bn_mask = bn_mask.transpose(0, 1)
if pad_mask is not None:
pad_size = (~bn_mask).sum()
mask_input = input[bn_mask, :]
else:
mask_input = input.clone()
mask_input = mask_input.reshape(-1, self.num_features)
input = input.permute(1, 2, 0).contiguous()
input_shape = input.size()
input = input.reshape(input.size(0), self.num_features, -1)
input = input.unsqueeze(-1)
if self.training:
self.iters.copy_(self.iters + 1)
output = PowerFunction.apply(input, self.weight, self.bias, self.running_phi, self.eps, \
self.afwd, self.abkw, self.ema_gz, self.debug, self.warmup_iters, self.iters, mask_input)
else:
N, C, H, W = input.size()
var = self.running_phi
output = input / (var + self.eps).sqrt()
output = self.weight.reshape(1,C,1,1) * output + self.bias.reshape(1,C,1,1)
output = output.reshape(input_shape)
output = output.permute(2, 0, 1).contiguous()
# Reshape it.
if shaped_input:
output = output.squeeze(0)
return output
class SyncPowerFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, bias, running_phi, eps, afwd, abkw, ema_gz,
debug, warmup_iters, current_iter, mask_x, process_group, world_size):
ctx.eps = eps
ctx.debug = debug
current_iter = current_iter.item()
ctx.current_iter = current_iter
ctx.warmup_iters = warmup_iters
ctx.abkw = abkw
ctx.process_group = process_group
ctx.world_size = world_size
N, C, H, W = x.size()
x2 = (mask_x * mask_x).mean(dim=0)
var = x2.reshape(1, C, 1, 1)
# Synchronize var across GPUs
if process_group is not None:
dist.all_reduce(var, op=dist.ReduceOp.AVG, group=process_group) # no need to divide, because already averaging
if current_iter <= warmup_iters:
z = x / (var + eps).sqrt()
else:
z = x / (running_phi + eps).sqrt()
y = z
ctx.save_for_backward(z, var, weight, ema_gz)
if current_iter < warmup_iters:
running_phi.copy_(running_phi * (current_iter-1)/current_iter + var.mean(dim=0, keepdim=True)/current_iter)
running_phi.copy_(afwd*running_phi + (1-afwd)*var.mean(dim=0, keepdim=True))
# Synchronize running_phi across all processes
if process_group is not None:
torch.distributed.all_reduce(running_phi, op=torch.distributed.ReduceOp.AVG, group=process_group)
y = weight.reshape(1,C,1,1) * y + bias.reshape(1,C,1,1)
controlled_print(f"Sync Forward - Input mean: {x.mean().item()}, std: {x.std().item()}")
controlled_print(f"Sync Forward - Weight mean: {weight.mean().item()}, std: {weight.std().item()}")
controlled_print(f"Sync Forward - Bias mean: {bias.mean().item()}, std: {bias.std().item()}")
controlled_print(f"Sync Forward - Running phi mean: {running_phi.mean().item()}, std: {running_phi.std().item()}")
controlled_print(f"Sync Forward - Var mean: {var.mean().item()}, std: {var.std().item()}")
controlled_print(f"Sync Forward - Output mean: {y.mean().item()}, std: {y.std().item()}")
return y
@staticmethod
def backward(ctx, grad_output):
eps = ctx.eps
debug = ctx.debug
current_iter = ctx.current_iter
warmup_iters = ctx.warmup_iters
abkw = ctx.abkw
N, C, H, W = grad_output.size()
z, var, weight, ema_gz = ctx.saved_tensors
y = z
g = grad_output * weight.reshape(1, C, 1, 1)
g = g * 1
gz = (g * z).mean(dim=3).mean(dim=2).mean(dim=0)
approx_grad_g = (g - (1 - abkw) * ema_gz * z)
ema_gz.add_((approx_grad_g * z).mean(dim=3, keepdim=True).mean(dim=2, keepdim=True).mean(dim=0, keepdim=True))
if ctx.process_group is not None:
dist.all_reduce(ema_gz, op=dist.ReduceOp.AVG, group=ctx.process_group)
gx = 1. / torch.sqrt(var + eps) * approx_grad_g
grad_weight = (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0)
grad_bias = grad_output.sum(dim=3).sum(dim=2).sum(dim=0)
is_sync = hasattr(ctx, 'process_group')
prefix = "Sync" if is_sync else "Original"
if is_sync:
process_group = ctx.process_group
world_size = ctx.world_size
dist.all_reduce(grad_weight, op=dist.ReduceOp.AVG, group=process_group)
dist.all_reduce(grad_bias, op=dist.ReduceOp.AVG, group=process_group)
controlled_print(f"{prefix} Backward - Grad output mean: {grad_output.mean().item()}, std: {grad_output.std().item()}")
controlled_print(f"{prefix} Backward - Grad input mean: {gx.mean().item()}, std: {gx.std().item()}")
controlled_print(f"{prefix} Backward - Grad weight mean: {grad_weight.mean().item()}, std: {grad_weight.std().item()}")
controlled_print(f"{prefix} Backward - Grad bias mean: {grad_bias.mean().item()}, std: {grad_bias.std().item()}")
return gx, grad_weight, grad_bias, None, None, None, None, None, None, None, None, None, None, None, None
class SyncMaskPowerNorm(nn.Module):
def __init__(self, num_features, eps=1e-5, alpha_fwd=0.9, alpha_bkw=0.9,
affine=True, warmup_iters=10000, group_num=1, process_group=None):
super().__init__()
self.num_features = num_features
self.eps = eps
self.affine = affine
self.process_group = process_group
self.world_size = dist.get_world_size(process_group) if process_group else 1
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.register_buffer('running_phi', torch.ones(1,num_features,1,1))
self.register_buffer('ema_gz', torch.zeros(1,num_features,1,1))
self.register_buffer('iters', torch.zeros(1).type(torch.LongTensor))
# Synchronize buffers
dist.all_reduce(self.running_phi, op=dist.ReduceOp.AVG, group=self.process_group)
dist.all_reduce(self.ema_gz, op=dist.ReduceOp.AVG, group=self.process_group)
# dist.all_reduce(self.iters, op=dist.ReduceOp.MAX, group=self.process_group)
self.afwd = alpha_fwd
self.abkw = alpha_bkw
self.eps = eps
self.debug = False
self.warmup_iters = warmup_iters
self.gp = GroupScaling1D(group_num=group_num)
self.group_num = group_num
def extra_repr(self):
return '{num_features}, eps={eps}, alpha_fwd={afwd}, alpha_bkw={abkw}, ' \
'affine={affine}, warmup={warmup_iters}, group_num={group_num}'.format(**self.__dict__)
def forward(self, input, pad_mask=None, is_encoder=False):
shaped_input = (len(input.shape) == 2)
if shaped_input:
input = input.unsqueeze(0)
if input.dim() == 4: # N, C, H, W
N, C, H, W = input.shape
input = input.permute(2, 3, 0, 1).contiguous().view(H*W, N, C)
T, B, C = input.shape
input = self.gp(input)
if pad_mask is None:
mask_input = input.clone()
else:
bn_mask = ~pad_mask
bn_mask = bn_mask.transpose(0, 1)
if pad_mask is not None:
pad_size = (~bn_mask).sum()
mask_input = input[bn_mask, :]
else:
mask_input = input.clone()
mask_input = mask_input.reshape(-1, self.num_features)
input = input.permute(1, 2, 0).contiguous()
input_shape = input.size()
input = input.reshape(input.size(0), self.num_features, -1)
input = input.unsqueeze(-1)
if self.training:
self.iters.copy_(self.iters + 1) # maybe consider syncing this, but unlikely
output = SyncPowerFunction.apply(input, self.weight, self.bias, self.running_phi, self.eps,
self.afwd, self.abkw, self.ema_gz, self.debug, self.warmup_iters, self.iters, mask_input,
self.process_group, self.world_size)
else:
N, C, H, W = input.size()
var = self.running_phi
output = input / (var + self.eps).sqrt()
output = self.weight.reshape(1,C,1,1) * output + self.bias.reshape(1,C,1,1)
output = output.reshape(input_shape)
output = output.permute(2, 0, 1).contiguous()
if shaped_input:
output = output.squeeze(0)
return output
class TestModel(nn.Module):
def __init__(self, norm_layer):
super(TestModel, self).__init__()
self.conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
self.norm = norm_layer(64)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
return self.relu(self.norm(self.conv(x)))
def run_test(local_rank, world_size):
dist.init_process_group(backend="nccl", init_method='env://', world_size=world_size, rank=local_rank)
# Set same seed for all processes
torch.manual_seed(42)
# Generate same data for all processes
batch_size = 32
torch.manual_seed(42)
data = torch.randn(batch_size, 3, 64, 64).cuda(local_rank)
# Ensure all processes have the same data
dist.broadcast(data, src=0)
if local_rank == 0:
# Run original MaskPowerNorm on single GPU
torch.manual_seed(42)
model_original = TestModel(lambda num_features: MaskPowerNorm(num_features)).cuda(local_rank)
out_original = model_original(data)
loss_original = out_original.sum()
loss_original.backward()
controlled_print("Running original PowerNorm on single GPU")
controlled_print(f"Original output mean: {out_original.mean().item()}")
# Run SyncMaskPowerNorm on all GPUs
torch.manual_seed(42)
model_sync = TestModel(lambda num_features: SyncMaskPowerNorm(num_features, process_group=dist.group.WORLD)).cuda(local_rank)
model_sync = DDP(model_sync, device_ids=[local_rank])
controlled_print("Running SyncPowerNorm")
out_sync = model_sync(data)
controlled_print(f"Sync output mean on rank {local_rank}: {out_sync.mean().item()}")
loss_sync = out_sync.sum()
loss_sync.backward()
if local_rank == 0:
for (name_o, param_o), (name_s, param_s) in zip(model_original.named_parameters(), model_sync.named_parameters()):
if param_o.grad is not None and param_s.grad is not None:
grad_diff = (param_o.grad - param_s.grad).abs().mean().item()
controlled_print(f"Gradient difference for {name_o}:")
controlled_print(f" Original - mean: {param_o.grad.mean().item()}, std: {param_o.grad.std().item()}")
controlled_print(f" Sync - mean: {param_s.grad.mean().item()}, std: {param_s.grad.std().item()}")
controlled_print(f" Absolute difference: {grad_diff}")
dist.barrier()
dist.destroy_process_group()
def main():
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
controlled_print(f"Running on rank {local_rank} of {world_size}")
run_test(local_rank, world_size)
if __name__ == "__main__":
main()
Only realised just now that this normalisation layer is meant for ViTs instead of language transformers
But in the paper, they were using this norm layer for machine translation and language modeling tasks.
Here's the final code instead
You didn't paste the total final code, only the forward of the nn.Module, so can you post the autograd.Function as well?
@lumliolum Ah sorry, there you go. Please check my full code above instead. I recommend Tensordock if you want to get started; something like 4x Nvidia A4000 already works, which costs about 0.5 USD per hour.
But in the paper, they were using this norm layer for machine translation and language modeling tasks.
Ah I see. I am just trying to make sure that I'm able to convert my .pth transformer model to Hugging Face format; then I will personally start training my models on a node of 8x H100 soon.
@lumliolum Hi, just wanted to check in. Any issues henceforth?