clone HF's `GPT2` to create `GPTMeg` with a few tiny changes.
As can be seen from https://github.com/bigscience-workshop/Megatron-DeepSpeed/pull/121, we have a divergence between Meg and HF GPT2, even when using the same weights under fp16.

So the proposed solution for letting users run BigScience-pretrained models is to create a new architecture that is an identical clone of HF's GPT2, but with a few changes.

Here are the 3 changes:
```python
import torch


def apply_overrides():

    # 1. layer norm needs to be done in fp32 and then cast back to fp16 to match meg.
    torch_layer_norm_orig = torch.layer_norm

    def torch_layer_norm_force_fp32(input, normalized_shape, weight, bias, eps, cudnn_enable):
        return torch_layer_norm_orig(
            input.float(), normalized_shape, weight.float(), bias.float(), eps, torch.backends.cudnn.enabled
        ).half()

    torch.layer_norm = torch_layer_norm_force_fp32

    # 2. MLP uses a slightly different activation function with a custom bwd
    import transformers.activations

    @torch.jit.script
    def gelu_megatron_fwd(x):
        return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

    @torch.jit.script
    def gelu_megatron_bwd(g, x):
        tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
        # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
        ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
        return ff * g

    class GeLUFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            ctx.save_for_backward(input)
            return gelu_megatron_fwd(input)

        @staticmethod
        def backward(ctx, grad_output):
            input, = ctx.saved_tensors
            return gelu_megatron_bwd(grad_output, input)

    transformers.activations.gelu_fast = GeLUFunction.apply
    transformers.activations.ACT2FN["gelu_fast"] = transformers.activations.gelu_fast

    # 3. torch.baddbmm() (meg) produces slightly different results than torch.matmul,
    #    so override GPT2Attention._attn to use torch.baddbmm
    import transformers.models.gpt2.modeling_gpt2
    from torch import nn

    def new_attn(self, query, key, value, attention_mask=None, head_mask=None):
        # [b, np, sq, sk]
        output_size = (query.size(0), key.size(1), query.size(2), key.size(2))

        # pre-allocate the result tensor for baddbmm
        matmul_result = torch.empty(
            output_size[0] * output_size[1], output_size[2], output_size[3],
            dtype=query.dtype, device=query.device,
        )

        factor = float(value.size(-1)) ** 0.5
        matmul_result = torch.baddbmm(
            matmul_result,
            query.reshape(-1, query.shape[2], query.shape[3]),            # [b * np, sq, hn]
            key.reshape(-1, key.shape[2], key.shape[3]).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=1.0 / factor,
        )
        attn_weights = matmul_result.view(*output_size)

        # attn_weights = torch.matmul(query, key.transpose(-1, -2))
        #
        # if self.scale_attn_weights:
        #     attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.Softmax(dim=-1)(attn_weights)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    transformers.models.gpt2.modeling_gpt2.GPT2Attention._attn = new_attn
```
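To show how the overrides would be used end to end, here is a minimal, illustrative sketch. Note the assumptions: `gpt2` is just a placeholder checkpoint name (the converted BigScience checkpoint would go there), and the activation override only takes effect if the checkpoint's config sets `activation_function="gelu_fast"` (stock `gpt2` uses `gelu_new`).

```python
# minimal usage sketch, not part of the proposed PR
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

apply_overrides()  # monkey-patch torch.layer_norm, gelu_fast and GPT2Attention._attn

# "gpt2" is a placeholder; the converted BigScience checkpoint would be used here,
# with activation_function="gelu_fast" in its config so the activation patch is picked up
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").half().cuda().eval()

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")
with torch.no_grad():
    logits = model(**inputs).logits  # should now track Megatron's fp16 numerics much more closely
```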
Here is how we are going to tackle the activation function: https://github.com/huggingface/transformers/issues/13997
So a PR will need to be filed with https://github.com/huggingface/transformers/.
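To give an idea of what the upstream change could look like, here is a rough sketch assuming the Megatron GeLU gets registered under its own `ACT2FN` key. The key name `gelu_megatron` is only a placeholder; the exact shape of the change is what the linked issue and the PR will decide.

```python
# sketch only: register the Megatron tanh-approximation GeLU as a named activation
import torch
import transformers.activations
from transformers import GPT2Config

@torch.jit.script
def gelu_megatron(x):
    # same tanh approximation as Megatron's fused GeLU (forward only; autograd derives the bwd)
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

# hypothetical key name -- to be settled in the transformers PR
transformers.activations.ACT2FN["gelu_megatron"] = gelu_megatron

# a model config would then select it like any other activation
config = GPT2Config(activation_function="gelu_megatron")
```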
If all source files can be easily identified, then perhaps the cloning could be done with a few Perl one-liners. Here is a very rough outline:
- find the pertinent source files: `grep -Irl GPT2 .`
- rename files/dirs while copying: `s/gpt2/gpt_meg/`
- rename internals: `s/GPT2/GPTMeg/g`
The hard-to-automate part is the index files, as there is only one of each.
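For illustration, here is a rough, untested Python sketch of the copy-and-rename step. The paths assume the transformers source layout under `src/transformers/models/`, and the index files (the `__init__.py` and auto mappings) would still need manual edits.

```python
# rough sketch of the copy-and-rename step from the outline above
from pathlib import Path

src = Path("src/transformers/models/gpt2")
dst = Path("src/transformers/models/gpt_meg")
dst.mkdir(parents=True, exist_ok=True)

for path in src.glob("*.py"):
    text = path.read_text()
    # s/GPT2/GPTMeg/g and s/gpt2/gpt_meg/g on the file contents
    text = text.replace("GPT2", "GPTMeg").replace("gpt2", "gpt_meg")
    # s/gpt2/gpt_meg/ on the file name while copying
    (dst / path.name.replace("gpt2", "gpt_meg")).write_text(text)
```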
Thanks for the write-up. I can work on this.
@sIncerass Let me know if there is anything I can help with!
FYI, we've created a fork to integrate the changes we need in transformers: https://github.com/bigscience-workshop/transformers. Feel free to make those changes there, and we'll merge them back into transformers when everything's ready.
We already have a PR https://github.com/huggingface/transformers/pull/14084 - nothing is holding us back from merging it, other than making sure it does the right thing.
Yes, of course. The fork is just there to centralize all contributions to BigScience. If you create a PR, we should merge it into the official repository and then update this fork. I will update the doc soon.