Megatron-DeepSpeed

Clone HF's `GPT2` to create `GPTMeg` with a few tiny changes

[Open] stas00 opened this issue 3 years ago • 6 comments

As can be seen from https://github.com/bigscience-workshop/Megatron-DeepSpeed/pull/121, Meg and HF GPT2 diverge even when using the same weights under fp16.

So, to enable users to use BigScience-pretrained models, the proposed solution is to create a new architecture that is an identical clone of HF's GPT2 except for a few changes.

Here are 3 changes:


```python
import torch
from torch import nn
import transformers.activations
import transformers.models.gpt2.modeling_gpt2


def apply_overrides():
    # Call before building/loading the model: GPT2MLP looks up ACT2FN at construction time.
    # Targets the transformers GPT2 code of the time (Oct 2021); later releases changed
    # ACT2FN and GPT2Attention._attn, so this may need adapting.

    # 1. layer norm needs to be done in fp32 and then cast back to fp16 to match meg.
    #    F.layer_norm calls torch.layer_norm(input, normalized_shape, weight, bias, eps, cudnn_enable),
    #    so patching it here affects every layer norm in the process.
    torch_layer_norm_orig = torch.layer_norm

    def torch_layer_norm_force_fp32(input, normalized_shape, weight=None, bias=None, eps=1e-5, cudnn_enable=True):
        weight = weight.float() if weight is not None else None
        bias = bias.float() if bias is not None else None
        out = torch_layer_norm_orig(input.float(), normalized_shape, weight, bias, eps, cudnn_enable)
        return out.to(input.dtype)  # back to fp16 when running in fp16

    torch.layer_norm = torch_layer_norm_force_fp32

    # 2. MLP uses a slightly different activation function with a custom bwd
    @torch.jit.script
    def gelu_megatron_fwd(x):
        # 0.79788456 ~= sqrt(2/pi)
        return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

    @torch.jit.script
    def gelu_megatron_bwd(g, x):
        tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
        # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
        ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
        return ff * g

    class GeLUFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            ctx.save_for_backward(input)
            return gelu_megatron_fwd(input)

        @staticmethod
        def backward(ctx, grad_output):
            (input,) = ctx.saved_tensors  # saved_tensors is a tuple
            # forward() has a single tensor input, so return exactly one gradient
            return gelu_megatron_bwd(grad_output, input)

    transformers.activations.gelu_fast = GeLUFunction.apply
    transformers.activations.ACT2FN["gelu_fast"] = transformers.activations.gelu_fast

    # 3. torch.baddbmm() (meg) produces slightly different results than torch.matmul, so override to use `torch.baddbmm`
    def new_attn(self, query, key, value, attention_mask=None, head_mask=None):
        # query: [b, np, sq, hn]; key, value: [b, np, sk, hn]
        output_size = (query.size(0), key.size(1), query.size(2), key.size(2))  # [b, np, sq, sk]
        matmul_result = torch.empty(
            output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query.dtype, device=query.device
        )

        # keep honoring the config flag (GPT2 default: scale_attn_weights=True)
        factor = float(value.size(-1)) ** 0.5 if self.scale_attn_weights else 1.0
        matmul_result = torch.baddbmm(
            matmul_result,  # uninitialized, ignored because beta=0.0
            query.reshape(-1, query.shape[2], query.shape[3]),  # [b * np, sq, hn]
            key.reshape(-1, key.shape[2], key.shape[3]).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=1.0 / factor,
        )
        attn_weights = matmul_result.view(*output_size)

        # This replaces the original HF computation:
        # attn_weights = torch.matmul(query, key.transpose(-1, -2))
        # if self.scale_attn_weights:
        #     attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    transformers.models.gpt2.modeling_gpt2.GPT2Attention._attn = new_attn
```
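A quick way to sanity-check the math of the `baddbmm` override outside of transformers is to compare the two score computations directly. This is a minimal sketch with arbitrary shapes (assumed, not taken from any real model), using `sk != sq` as happens with cached keys, which is the case where the key must be reshaped with its own `key.shape`. In fp32 on CPU the two paths should agree up to rounding; this does not reproduce the fp16 divergence reported in PR 121, it only shows that `baddbmm` with `beta=0.0` and `alpha=1/sqrt(hn)` computes the same quantity as `matmul` followed by division.

```python
import torch

torch.manual_seed(0)

# arbitrary illustration shapes: batch, heads, query len, key len (e.g. with cached keys), head dim
b, n_heads, sq, sk, hn = 2, 3, 4, 6, 8
query = torch.randn(b, n_heads, sq, hn)
key = torch.randn(b, n_heads, sk, hn)

# HF-style: batched matmul, then divide by sqrt(head_dim)
scores_matmul = torch.matmul(query, key.transpose(-1, -2)) / (hn ** 0.5)

# Megatron-style: flatten [b, np] into one batch dim and fold the scale into alpha
buf = torch.empty(b * n_heads, sq, sk)  # contents ignored because beta=0.0
scores_baddbmm = torch.baddbmm(
    buf,
    query.reshape(-1, sq, hn),  # [b * np, sq, hn]
    key.reshape(-1, sk, hn).transpose(1, 2),  # [b * np, hn, sk]
    beta=0.0,
    alpha=1.0 / hn ** 0.5,
).view(b, n_heads, sq, sk)

print(scores_baddbmm.shape)  # torch.Size([2, 3, 4, 6])
print((scores_matmul - scores_baddbmm).abs().max().item())
```

Any remaining difference comes from the different kernels and from where the scaling is applied, which is exactly the kind of low-order-bit mismatch that matters when comparing fp16 outputs.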

Here is how we are going to tackle the activation function: https://github.com/huggingface/transformers/issues/13997

So a PR will need to be filed with https://github.com/huggingface/transformers/
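Whichever way the activation ends up in transformers, the hand-written backward in step 2 should equal the derivative of the forward. With u = 0.79788456 * x * (1 + 0.044715 * x^2), the derivative of 0.5 * x * (1 + tanh(u)) is 0.5 * (1 + tanh(u)) + 0.5 * x * (1 - tanh(u)^2) * (0.79788456 + 0.1070322243 * x^2). The sketch below re-implements both formulas on plain Python floats (so it needs no torch) and compares the analytic gradient with a central finite difference at a few arbitrarily chosen points. It checks the formula only, not the `autograd.Function` wiring.

```python
import math

SQRT_2_OVER_PI = 0.79788456  # constant used in the Megatron kernels above
COEF = 0.044715


def gelu_fwd(x):
    return x * 0.5 * (1.0 + math.tanh(SQRT_2_OVER_PI * x * (1 + COEF * x * x)))


def gelu_bwd(g, x):
    # same expression as gelu_megatron_bwd, on floats instead of tensors
    tanh_out = math.tanh(SQRT_2_OVER_PI * x * (1 + COEF * x * x))
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (SQRT_2_OVER_PI + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff * g


def central_difference(f, x, h=1e-5):
    # numerical derivative, error O(h^2)
    return (f(x + h) - f(x - h)) / (2 * h)


xs = [-4.0, -1.5, -0.3, 0.0, 0.7, 2.0, 5.0]
max_err = max(abs(gelu_bwd(1.0, x) - central_difference(gelu_fwd, x)) for x in xs)
print(f"max |analytic - numeric| = {max_err:.2e}")
print(gelu_bwd(1.0, 0.0))  # 0.5
```

In the real override the same comparison could be done on tensors with `torch.autograd.gradcheck`, in float64.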

stas00 • Oct 16 '21 03:10

If all the source files can be easily identified, perhaps the cloning could be done with a few perl one-liners. Here is a very rough outline:

  1. find the pertinent source files: `grep -Irl GPT2 .`
  2. rename files/dirs while copying: `s/gpt2/gpt_meg/`
  3. rename internals: `s/GPT2/GPTMeg/g`

The hard part to automate is the index files, as there is only one of each, so the new entries have to be added to them rather than cloned.
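The rough outline above could also be sketched in Python instead of perl one-liners. The helper name `clone_gpt2_as_gptmeg` and the toy directory layout are hypothetical. It handles only the mechanical steps: files that mention `GPT2` but have no `gpt2` in their path (the shared index files) are reported rather than copied, since cloning them would overwrite the originals.

```python
import tempfile
from pathlib import Path


def clone_gpt2_as_gptmeg(root):
    """Hypothetical helper following the three-step outline.

    Returns (created, needs_manual_edit) as POSIX-style paths relative to root.
    """
    root = Path(root)
    created, needs_manual_edit = [], []
    # materialize the file list first so newly written copies are not revisited
    files = sorted(p for p in root.rglob("*") if p.is_file())
    for path in files:
        try:
            text = path.read_text(encoding="utf-8")
        except UnicodeDecodeError:
            continue  # skip binary files, like grep -I
        if "GPT2" not in text:  # step 1: grep -Irl GPT2 .
            continue
        rel = path.relative_to(root).as_posix()
        if "gpt2" not in rel:
            # shared index file (only one of each): must be edited by hand
            needs_manual_edit.append(rel)
            continue
        # step 2: gpt2 -> gpt_meg on the path (every occurrence, i.e. like s///g)
        new_path = root / rel.replace("gpt2", "gpt_meg")
        new_path.parent.mkdir(parents=True, exist_ok=True)
        # step 3: s/GPT2/GPTMeg/g on the contents
        new_path.write_text(text.replace("GPT2", "GPTMeg"), encoding="utf-8")
        created.append(new_path.relative_to(root).as_posix())
    return created, needs_manual_edit


# Usage on a toy tree (hypothetical layout)
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "models" / "gpt2").mkdir(parents=True)
    (root / "models" / "gpt2" / "modeling_gpt2.py").write_text("class GPT2Model:\n    pass\n", encoding="utf-8")
    (root / "models" / "__init__.py").write_text("from .gpt2 import GPT2Model\n", encoding="utf-8")
    created, manual = clone_gpt2_as_gptmeg(root)
    cloned_text = (root / created[0]).read_text(encoding="utf-8")

print(created)  # ['models/gpt_meg/modeling_gpt_meg.py']
print(manual)  # ['models/__init__.py']
```

As in the outline, step 3 only rewrites `GPT2`; lowercase `gpt2` inside file contents (for example relative imports of the renamed modules such as `.modeling_gpt2`) is left alone and would need the same substitution or a manual fix.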

stas00 • Oct 16 '21 03:10

Thanks for the write-up. I can work on this.

sIncerass • Oct 16 '21 04:10

@sIncerass Let me know if there is anything I can help with!

jaketae • Oct 16 '21 07:10

FYI, we've created a fork to integrate the changes we need in transformers: https://github.com/bigscience-workshop/transformers. Feel free to make those changes there, and we'll merge them back into transformers when everything's ready.

thomasw21 • Oct 26 '21 11:10

We already have a PR, https://github.com/huggingface/transformers/pull/14084, and nothing is holding us back from merging it other than making sure it does the right thing.

stas00 • Oct 26 '21 16:10

Yes, of course. The fork is just here to centralize all BigScience contributions. If you create a PR, we should merge it into the official repository and then update this fork. I will update the doc soon.

thomasw21 • Oct 26 '21 16:10