Alex McKinney

71 comments by Alex McKinney

Update: I now have a full model working. I haven't checked if the pretrained weight loading wrappers (provided by the Flax GPTNeo implementation) work yet, but once they are it...

Okay, thanks for the guidance and helper scripts 🔥 I expected that this lack of precision was not normal 😅 I'll get the pretrained wrappers working first and then focus...

I've begun my hunt for numerical bugs 🐛 The first I squashed was rather strange. It seems `torch.rsqrt` and `jax.lax.rsqrt` do not match. This is used in the RMSNorm layers....
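For reference, here is a minimal pure-Python sketch of an RMSNorm layer. It assumes the common form in which `eps` is added inside the square root. The `inv_rms` line is the "rsqrt" step: in the real layers it is a call to `torch.rsqrt` or `jax.lax.rsqrt`, so that line is where any disagreement between the two would surface. The function name and the `eps` default are illustrative assumptions, not taken from the PR.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Sketch of RMSNorm over a 1-D list: x * rsqrt(mean(x^2) + eps) * weight.

    Assumption: eps is added inside the square root, as in many RMSNorm layers.
    """
    mean_sq = sum(v * v for v in x) / len(x)
    # The reciprocal-square-root step; frameworks call rsqrt here.
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]
```

With unit weights and `eps=0.0`, the output has a root-mean-square of exactly 1 (up to rounding). This makes it an easy invariant to assert when comparing two implementations.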

@sanchit-gandhi The model now numerically matches in fp32 on CPU. The issue was that my backend had changed from CPU to GPU since I fixed the `rsqrt` issue. I don't think we...
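A cross-framework equivalence check like this usually comes down to comparing the maximum absolute difference against a tolerance. The following is a minimal sketch. The helper names and the `atol=1e-5` default are illustrative assumptions, not the threshold used in the PR. Running both sides on the same backend matters, because GPU kernels can legitimately differ from CPU results in the low bits, for example through a different reduction order.

```python
def max_abs_diff(a, b):
    """Largest elementwise absolute difference between two flat sequences."""
    if len(a) != len(b):
        raise ValueError("outputs have different lengths")
    return max(abs(x - y) for x, y in zip(a, b))


def outputs_match(a, b, atol=1e-5):
    """True if every element of a and b agrees to within atol (assumed fp32-ish tolerance)."""
    return max_abs_diff(a, b) <= atol
```

In practice `a` and `b` would be the flattened hidden states or logits from the two implementations. Printing `max_abs_diff` is often more informative than the boolean alone.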

Awesome, thanks, tests and docs it is! I am currently on leave, so I won't be progressing on this until the 31st. > That's interesting - are we loading the weights...

Hi, I've been pretty split responsibility-wise recently (moving house and job!!), so I have only made a small bit of progress. Most of the tests pass; however, there seems to...

The final tests ended up being easy to fix: I had simply forgotten to swap the attention mask and position ids in the pretrained model wrapper. @sanchit-gandhi I haven't retested...
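This class of bug is easy to reproduce in miniature. When a wrapper forwards `attention_mask` and `position_ids` positionally, a mismatch in argument order silently swaps them, and nothing fails until the outputs are compared. The sketch below uses hypothetical function names and is not the actual Flax pretrained-model wrapper code. It shows how passing keyword arguments removes the dependence on order.

```python
def hypothetical_module_call(input_ids, attention_mask, position_ids):
    # Stand-in for a model's forward call: echo back what it received, labelled.
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }


def buggy_wrapper(input_ids, attention_mask, position_ids):
    # Positional call with the last two arguments in the wrong order.
    return hypothetical_module_call(input_ids, position_ids, attention_mask)


def fixed_wrapper(input_ids, attention_mask, position_ids):
    # Keyword arguments make the call independent of parameter order.
    return hypothetical_module_call(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
    )
```

Because both tensors are usually integer arrays of the same shape, a swap like this raises no error. That is why it tends to show up only as a numerical mismatch in the tests.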

@sanchit-gandhi all tests pass locally :tada: I've also run the model using the `generate` API to see if the outputs make sense: ``` In [23]: inputs = tokenizer('Aloha, World!',...

Thanks for your additional comments, I have some time to work on the more involved points today 🤗

@sanchit-gandhi I think everything except the missing weight issue is resolved now (see my comment). Trying to resolve some remaining CI issues, I noticed that the line `# Copied from...