
Add mixed precision inference incl loading

neelnanda-io opened this issue 1 year ago • 11 comments

Add the option to load models in bfloat16 and float16. This is especially important for large models like GPT-J and GPT-NeoX.

Ideally, we would load from HuggingFace in this low precision, do weight processing on the CPU, and then move the processed model weights to the GPU. It might be easiest to do the weight processing once and cache the result to HF (see #103 ).
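
For concreteness, a rough sketch of that flow. This is a hedged sketch, not the final API: it assumes from_pretrained's hf_model and device keyword arguments can be used this way, and uses GPT-J purely as an example.

import torch
import transformers
from transformer_lens import HookedTransformer

# Load the HuggingFace weights directly in low precision
hf_model = transformers.AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B", torch_dtype=torch.bfloat16
)

# Do the TransformerLens weight processing (LN folding, centering, ...) on the CPU,
# reusing the already-loaded low-precision weights instead of reloading them
model = HookedTransformer.from_pretrained(
    "EleutherAI/gpt-j-6B",
    hf_model=hf_model,
    device="cpu",
)

# Only then move the processed weights to the GPU
model = model.to("cuda")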

neelnanda-io avatar Dec 19 '22 11:12 neelnanda-io

Maybe covered by #125

neelnanda-io avatar Jan 20 '23 14:01 neelnanda-io

Solved with #298

Edit: actually not solved yet; there are still problems with HookedTransformer.generate, and perhaps optimizations to make. I'm preparing a commit.

glerzing avatar Jun 03 '23 17:06 glerzing

Before I found this issue, I didn't know @glerzing was working on #317, so I was planning to report this separately.

Anyway, despite progress, I thought I'd share a demo where I get nans running in float16:

import torch
from transformer_lens import HookedTransformer

torch.set_grad_enabled(False)

# Issue #1: there's no way to load in float16 at initialization, so we're
# forced to load in float32 and then convert to float16.
model32 = HookedTransformer.from_pretrained("EleutherAI/pythia-70m-deduped")
print(repr(model32.to_string(model32(" Unable")[0, -1].argmax())))

# Issue #2:
model16 = HookedTransformer.from_pretrained("EleutherAI/pythia-70m-deduped").to(
    torch.float16
)
print(model16(" Unable"))

Outputs:

Using pad_token, but it is not set yet.
Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer
' to'
Using pad_token, but it is not set yet.
Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer
Changing model dtype to torch.float16
tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]]], device='cuda:0',
       dtype=torch.float16)

Thanks for the progress on #317!!! Psyched to see it merged.

tbenthompson avatar Jun 13 '23 00:06 tbenthompson

I believe that Pythia 70m can have attention scores as low as -100,000, which will get you NaNs in float16 because float16 bottoms out around -65,504. Honestly, my take is that this is not our problem and you should use bfloat16 instead, so long as HuggingFace *also* gives you NaNs. I have no clue why Pythia is this high lol.
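
For reference, a quick check of the dtype ranges involved (nothing TransformerLens-specific): float16's largest finite magnitude is about 65,504, while bfloat16 keeps float32's exponent range at lower precision.

import torch

score = torch.tensor([-100_000.0])
print(score.to(torch.float16))   # overflows to -inf
print(score.to(torch.bfloat16))  # stays finite (just rounded)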


neelnanda-io avatar Jun 13 '23 09:06 neelnanda-io

Wow, that's fascinating about the giant attention scores!!

I'm seeing big differences between HuggingFace and TL on Pythia 410M in both bfloat16 and float16. I suspected that the TL weight processing (fold LN, center unembed, etc.) was causing the differences, so I tried from_pretrained_no_processing, but the differences persist.

I'm gradually learning more about the internals of TL so if I have time soon, I'll dig in on this and try to figure out what's going on.

Source
import torch
from transformer_lens import HookedTransformer
import transformers

torch.set_grad_enabled(False)

model_name = "EleutherAI/pythia-410m-deduped"
model32 = HookedTransformer.from_pretrained_no_processing(model_name)
logits32 = model32(" Unable", prepend_bos=False)[0, -1]
p32 = torch.softmax(logits32, dim=-1)
del model32

model16 = HookedTransformer.from_pretrained_no_processing(model_name).to(torch.float16)
logits16 = model16(" Unable", prepend_bos=False)[0, -1]
p16 = torch.softmax(logits16, dim=-1)
del model16

modelB16 = HookedTransformer.from_pretrained_no_processing(model_name).to(torch.bfloat16)
logitsB16 = modelB16(" Unable", prepend_bos=False)[0, -1]
pB16 = torch.softmax(logitsB16, dim=-1)
del modelB16

tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

hf_model32 = transformers.GPTNeoXForCausalLM.from_pretrained(
    model_name, low_cpu_mem_usage=True, torch_dtype=torch.float32
).cuda()
hf_logits32 = hf_model32(tokenizer(" Unable", return_tensors="pt")["input_ids"].cuda()).logits[0, 0, :]
hf_p32 = torch.softmax(hf_logits32, dim=-1)
del hf_model32

hf_modelB16 = transformers.GPTNeoXForCausalLM.from_pretrained(
    model_name, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
).cuda()
hf_logitsB16 = hf_modelB16(tokenizer(" Unable", return_tensors="pt")["input_ids"].cuda()).logits[0, 0, :]
hf_pB16 = torch.softmax(hf_logitsB16, dim=-1)
del hf_modelB16

hf_model16 = transformers.GPTNeoXForCausalLM.from_pretrained(
    model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16
).cuda()
hf_logits16 = hf_model16(tokenizer(" Unable", return_tensors="pt")["input_ids"].cuda()).logits[0, 0, :]
hf_p16 = torch.softmax(hf_logits16, dim=-1)
del hf_model16

print(f'TL, float32    top_token={repr(tokenizer.decode(logits32.argmax())):<16} p={p32.max().item():.3f}')
print(f'TL, bfloat16   top_token={repr(tokenizer.decode(logitsB16.argmax())):<16} p={pB16.max().item():.3f}')
print(f'TL, float16    top_token={repr(tokenizer.decode(logits16.argmax())):<16} p={p16.max().item():.3f}')
print(f'HF, float32    top_token={repr(tokenizer.decode(hf_logits32.argmax())):<16} p={hf_p32.max().item():.3f}')
print(f'HF, bfloat16   top_token={repr(tokenizer.decode(hf_logitsB16.argmax())):<16} p={hf_pB16.max().item():.3f}')
print(f'HF, float16    top_token={repr(tokenizer.decode(hf_logits16.argmax())):<16} p={hf_p16.max().item():.3f}')

Output

TL, float32    top_token=' to'            p=0.745
TL, bfloat16   top_token=' to'            p=0.641
TL, float16    top_token='\n'             p=0.000
HF, float32    top_token=' to'            p=0.745
HF, bfloat16   top_token=' to'            p=0.746
HF, float16    top_token='\n'             p=0.750

For HF: the float32 and bfloat16 results are "good". The float16 results are bad, but I'm assuming that's just because some internal activations are out of range for fp16! For TL: the float32 results are good, but something is going wrong in both bfloat16 and float16, beyond the expected numerical issues.

tbenthompson avatar Jun 13 '23 12:06 tbenthompson

I've also observed this. My weak guess is that it's due to implementation details like the use of einsum and the reshaping of attention matrices? I'm not sure what else would be implemented differently.

It would be interesting to carefully go through each activation and compare TL and HF, and I'd be curious what you find!
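
One rough way to set up that comparison is sketched below. It assumes TL's "blocks.{i}.hook_resid_post" lines up with HF's hidden_states[i + 1] (HF's hidden_states[0] is the embedding output, and its final entry has the final layer norm applied, so the last block is skipped); both assumptions are worth double-checking.

import torch
import transformers
from transformer_lens import HookedTransformer

torch.set_grad_enabled(False)
model_name = "EleutherAI/pythia-410m-deduped"

tl_model = HookedTransformer.from_pretrained_no_processing(model_name).to(torch.float16)
hf_model = transformers.GPTNeoXForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16
).cuda()
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

tokens = tokenizer(" Unable", return_tensors="pt")["input_ids"].cuda()
_, cache = tl_model.run_with_cache(tokens)
hf_hidden = hf_model(tokens, output_hidden_states=True).hidden_states

# Compare the residual stream after each block; the last block is skipped
# because HF's final hidden_states entry is taken after the final layer norm.
for i in range(tl_model.cfg.n_layers - 1):
    diff = (cache[f"blocks.{i}.hook_resid_post"] - hf_hidden[i + 1]).abs().max().item()
    print(f"block {i}: max |TL - HF| = {diff:.4f}")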


neelnanda-io avatar Jun 13 '23 13:06 neelnanda-io

I think these two changes fix the float16 issue:

  1. Keep LayerNorm in float32.
  2. Apply the attention scale before computing the attention scores: instead of dividing the scores by attn_scale, divide both q and k by sqrt(attn_scale).
class LayerNorm(nn.Module):
    ...

    def forward(self, x):
        x_type = x.dtype
        x = x.to(torch.float32)  # normalize in float32, cast back at the end

        x = x - x.mean(dim=-1, keepdim=True)  # [batch, pos, length]
        scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
            (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
        )
        x = x / scale  # [batch, pos, length]

        return self.hook_normalized(x * self.w + self.b).to(x_type)


class Attention(nn.Module):
    def __init__(self, cfg):
        ...
        self.attn_scale = np.sqrt(np.sqrt(self.cfg.d_head))  # d_head ** 0.25
        ...

    def forward(self):
        ...
        # scale q and k separately so the score matmul stays in float16 range
        q = q / self.attn_scale
        k = k / self.attn_scale

        attn_scores = (
            einsum(
                "batch query_pos head_index d_head, \
                    batch key_pos head_index d_head \
                    -> batch head_index query_pos key_pos",
                q,
                k,
            )
            # / self.attn_scale  # REMOVE THIS LINE
        )
        ...

Now running the above script gives

TL, float32    top_token=' to'            p=0.745
TL, bfloat16   top_token=' to'            p=0.641
TL, float16    top_token=' to'            p=0.740
Details

It looks like LayerNorm should stay in float32: https://github.com/pytorch/pytorch/issues/66707

When running the above test script, I saw that attention scores are reasonable for most of the forward pass but get very large (and negative) in the last two blocks. Without the attention fix, this results in -inf when running in float16. The HuggingFace implementation uses `torch.baddbmm`, which does the matmul and the scaling in one operation.
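
A small sketch of that fused pattern (illustrative shapes, not TL's or HF's actual code): torch.baddbmm applies the 1/sqrt(d_head) factor via alpha inside the same call as the batched matmul, rather than computing unscaled scores and dividing afterwards, which is the step that can overflow in float16.

import torch

batch_heads, q_len, k_len, d_head = 8, 4, 4, 64
q = torch.randn(batch_heads, q_len, d_head)
k = torch.randn(batch_heads, k_len, d_head)

# beta=0 ignores the zero-initialized input; alpha scales during the matmul
zeros = torch.zeros(batch_heads, q_len, k_len)
fused = torch.baddbmm(zeros, q, k.transpose(1, 2), beta=0.0, alpha=1.0 / d_head**0.5)
unfused = (q @ k.transpose(1, 2)) / d_head**0.5
print(torch.allclose(fused, unfused, atol=1e-6))  # same result, different route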

slavachalnev avatar Jun 19 '23 13:06 slavachalnev

I believe that Pythia 70m can have attention scores as low as -100,000, which will get you NaNs in float16 because float16 bottoms out around -65,504. Honestly, my take is that this is not our problem and you should use bfloat16 instead, so long as HuggingFace *also* gives you NaNs. I have no clue why Pythia is this high lol.

So I think Theo Horsley might have discovered why: Pythia models have large bias vectors on the K and Q values (he said the K bias vector for one head had norm around 300, which is especially silly given it's just a constant offset). At least in normal supervised ML you don't apply L2 regularization to bias terms, so I assume there is similarly no weight decay on the attention biases, and so they end up large and blow up the attention scores.
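
A quick way to check this claim, assuming TL exposes the stacked bias parameters as model.b_K / model.b_Q with shape [n_layers, n_heads, d_head] (worth verifying against your TransformerLens version):

import torch
from transformer_lens import HookedTransformer

# Use no_processing so the raw checkpoint biases are inspected,
# rather than biases modified by LayerNorm folding etc.
model = HookedTransformer.from_pretrained_no_processing("EleutherAI/pythia-70m-deduped")

k_bias_norms = model.b_K.norm(dim=-1)  # [n_layers, n_heads]
q_bias_norms = model.b_Q.norm(dim=-1)
print(k_bias_norms.max().item(), q_bias_norms.max().item())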

wesg52 avatar Jul 27 '23 17:07 wesg52

Interesting! Note that Pythia uses rotary attention, where b_K does matter (the key gets rotated by the difference in positions, so it doesn't cancel out between different source tokens).


neelnanda-io avatar Jul 27 '23 20:07 neelnanda-io

I just re-ran the test above with TL 1.5.0 and I'm getting much better results, but there are still noticeable discrepancies from the HF implementation:

TL, float32    top_token=' to'            p=0.745
TL, bfloat16   top_token=' to'            p=0.730
TL, float16    top_token=' to'            p=0.737
HF, float32    top_token=' to'            p=0.745
HF, bfloat16   top_token=' to'            p=0.746
HF, float16    top_token=' to'            p=0.750

Since Pythia was trained in float16, we should probably ignore the bfloat16 comparison, but the discrepancy in float16 is still noticeable.

Thanks glerzing and slavachalnev for the improvements!

tbenthompson avatar Aug 09 '23 18:08 tbenthompson

I think the main explanation for the differences is the fact that TransformerLens uses einsum instead of Linear layers or Conv1D.
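
A hedged illustration of that point (not TL's or HF's actual code): the same projection computed as a per-head einsum versus a single fused matmul. The two are mathematically identical, but in float16 the kernels can accumulate in different orders, so the outputs need not match bit-for-bit; the size of the gap depends on hardware and PyTorch/cuBLAS versions, and assumes a CUDA device (CPU float16 matmul support varies by PyTorch version).

import torch

torch.manual_seed(0)
n_heads, d_model, d_head = 16, 1024, 64
x = torch.randn(1, 8, d_model, dtype=torch.float16, device="cuda")
W = torch.randn(n_heads, d_model, d_head, dtype=torch.float16, device="cuda") * 0.02

# Per-head einsum, in the style of TransformerLens
out_einsum = torch.einsum("bpd,ndh->bpnh", x, W)

# Single fused matmul, in the style of a Linear layer with weight [d_model, n_heads * d_head]
W_fused = W.permute(1, 0, 2).reshape(d_model, n_heads * d_head)
out_linear = (x @ W_fused).reshape(1, 8, n_heads, d_head)

print((out_einsum - out_linear).abs().max().item())  # may be zero or a small nonzero gap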

glerzing avatar Aug 09 '23 20:08 glerzing