PyTorch huggingface transformers implementation by keyfan.

Open • LagPixelLOL opened this issue 11 months ago • 4 comments

Someone has made a full Hugging Face implementation that is much better than mine, so please use it instead!

https://huggingface.co/keyfan/grok-1-hf

Previous comment

Very rough implementation, may be broken

Part of the code is copied from Hugging Face's Mixtral implementation. Grok-1's attention is not standard, so I don't think FlashAttention can be used. I've tested it very little, but the input "1 + 1 = " outputs "2", so it seems to be at least partly working. If anyone can build on this to make a proper Hugging Face transformers port, it would be greatly appreciated! The weights are uploaded to https://huggingface.co/v2ray/grok-1-pytorch.
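The non-standard attention is presumably the tanh soft cap that xAI's released Grok-1 JAX code applies to the attention logits before the softmax (assumed here to be `30 * tanh(logits / 30)`). A fused kernel such as FlashAttention computes the softmax inside the kernel, so unless it exposes a hook for an extra elementwise step on the logits, it can't apply the cap. The pure-Python sketch contrasts plain softmax attention weights with capped ones; the cap value and all function names are assumptions for illustration, not code from this port.

```python
import math
from typing import List

# Assumption: cap * tanh(x / cap) with cap = 30.0 is taken to match xAI's
# released Grok-1 JAX code; this is an illustration, not code from the port.
ATTN_LOGIT_CAP = 30.0


def soft_cap(x: float, cap: float = ATTN_LOGIT_CAP) -> float:
    """Smoothly squash a logit into the range [-cap, cap]."""
    return cap * math.tanh(x / cap)


def softmax(xs: List[float]) -> List[float]:
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def standard_attention_weights(query, keys, scale):
    """softmax(scale * q.k): the weights a standard fused attention kernel uses."""
    return softmax([scale * dot(query, k) for k in keys])


def capped_attention_weights(query, keys, scale, cap=ATTN_LOGIT_CAP):
    """Same as above, but each logit is soft-capped before the softmax."""
    return softmax([soft_cap(scale * dot(query, k), cap) for k in keys])
```

With logits of 100 and 0 (for example `query=[10.0, 0.0]`, `keys=[[10.0, 0.0], [0.0, 10.0]]`, `scale=1.0`), the standard weights put about e^-100 on the second key, while the capped version turns 100 into 30·tanh(100/30) ≈ 29.9 and puts about e^-29.9 there, so very large scores are damped. For small logits the cap is close to the identity and the two versions agree.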

```python
>>> inp = torch.tensor([[2]+t.encode("2 + 3 =")+[130089]])
>>> with torch.no_grad():
...  r = m(inp)
... 
>>> r = torch.nn.functional.softmax(r[0][:, -1, :], dim=-1)
>>> a, b = torch.topk(r, 1, dim=-1)
>>> b
tensor([[19]])
>>> t.decode(19)
'5'
```
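The session takes a softmax over the last position's logits and then `topk(..., 1)`. Because softmax is monotonic, that picks the same token as an argmax over the raw logits, so for greedy decoding the softmax step is optional. Repeating that single step in a loop gives plain greedy generation. A minimal pure-Python sketch, where `model` is a hypothetical callable that takes the full token-ID list and returns only the next-position logits (the real model returns a `[batch, seq, vocab]` tensor, as the `r[0][:, -1, :]` indexing suggests), and `greedy_generate` and `toy_model` are illustrative names:

```python
import math
from typing import Callable, List, Optional


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax; used only to show it doesn't change the pick."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values: List[float]) -> int:
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)


def greedy_generate(
    model: Callable[[List[int]], List[float]],
    prompt_ids: List[int],
    max_new_tokens: int,
    eos_id: Optional[int] = None,
) -> List[int]:
    """Append the highest-scoring next token one step at a time."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        next_id = argmax(model(ids))  # same choice as topk(softmax(logits), 1)
        ids.append(next_id)
        if eos_id is not None and next_id == eos_id:
            break
    return ids


def toy_model(ids: List[int]) -> List[float]:
    """Hypothetical stand-in model: always favours (last token + 1) mod 5."""
    logits = [0.0] * 5
    logits[(ids[-1] + 1) % 5] = 1.0
    return logits


print(greedy_generate(toy_model, [0], max_new_tokens=3))  # [0, 1, 2, 3]
```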

LagPixelLOL, Mar 20 '24 10:03

I tried to run this on 8x A10 in 4-bit and got this error. Unless I miscalculated, it should fit on 8x A10 (24 GB each), right?

```
ValueError:
Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit
the quantized model. If you want to dispatch the model on the CPU or the disk while keeping
these modules in 32-bit, you need to set load_in_8bit_fp32_cpu_offload=True and pass a custom
device_map to from_pretrained. Check
https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu
for more details.
```
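Whether 4-bit Grok-1 fits on 8x A10 can be checked with rough arithmetic. This sketch assumes Grok-1's publicly stated 314B parameters, exactly 4 bits per weight, and the 64 decoder layers that appear in the device map posted later in this thread. It ignores quantization scale factors, any modules kept in 16-bit, activations, the KV cache, and per-GPU CUDA overhead, all of which eat into the headroom. All names are illustrative.

```python
# Assumptions (not measured): 314e9 parameters, 4 bits per weight,
# 8 GPUs with 24 GB each, weights spread over 64 decoder layers.
PARAMS = 314e9
BITS_PER_WEIGHT = 4
NUM_GPUS = 8
GB_PER_GPU = 24
NUM_LAYERS = 64

GB = 1e9  # decimal gigabytes, as GPU capacities are usually quoted


def weight_gb(params: float, bits: int) -> float:
    """Storage for the raw quantized weights only, in decimal GB."""
    return params * bits / 8 / GB


total_weights = weight_gb(PARAMS, BITS_PER_WEIGHT)  # ~157 GB
total_vram = NUM_GPUS * GB_PER_GPU                  # 192 GB
per_layer = total_weights / NUM_LAYERS              # ~2.45 GB per layer
per_gpu_if_even = total_weights / NUM_GPUS          # ~19.6 GB per GPU
headroom_per_gpu = GB_PER_GPU - per_gpu_if_even     # ~4.4 GB per GPU

print(f"weights: {total_weights:.1f} GB of {total_vram} GB total")
print(f"per layer: {per_layer:.2f} GB, per GPU (even split): {per_gpu_if_even:.1f} GB")
print(f"headroom per GPU: {headroom_per_gpu:.1f} GB")
```

On paper the weights fit with roughly 4 GB to spare per GPU, which lines up with the ~20 GB per GPU reported on A100s later in the thread. That margin is thin, so any extra room an automatic placement reserves for buffers can plausibly be enough to push some modules onto the CPU.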

nivibilla, Mar 21 '24 11:03

It did load on the CPU, but I wasn't able to run inference. I've uploaded the 4-bit version here anyway in case anyone can test it: https://huggingface.co/eastwind/grok-1-hf-4bit

nivibilla, Mar 21 '24 11:03

@nivibilla I just tried loading it on 8x A100 80GB, and it used 20 GB of VRAM on each GPU. In your case, device_map="auto" may have miscalculated some usage and placed some of the modules on the CPU. Could you try again with this device map?

```python
{'transformer.in_out_embed': 0, 'lm_head': 0,
 'transformer.decoder_layer.0': 0, 'transformer.decoder_layer.1': 0, 'transformer.decoder_layer.2': 0,
 'transformer.decoder_layer.3': 0, 'transformer.decoder_layer.4': 0, 'transformer.decoder_layer.5': 0,
 'transformer.decoder_layer.6': 1, 'transformer.decoder_layer.7': 1, 'transformer.decoder_layer.8': 1, 'transformer.decoder_layer.9': 1,
 'transformer.decoder_layer.10': 1, 'transformer.decoder_layer.11': 1, 'transformer.decoder_layer.12': 1, 'transformer.decoder_layer.13': 1,
 'transformer.decoder_layer.14': 2, 'transformer.decoder_layer.15': 2, 'transformer.decoder_layer.16': 2, 'transformer.decoder_layer.17': 2,
 'transformer.decoder_layer.18': 2, 'transformer.decoder_layer.19': 2, 'transformer.decoder_layer.20': 2, 'transformer.decoder_layer.21': 2,
 'transformer.decoder_layer.22': 3, 'transformer.decoder_layer.23': 3, 'transformer.decoder_layer.24': 3, 'transformer.decoder_layer.25': 3,
 'transformer.decoder_layer.26': 3, 'transformer.decoder_layer.27': 3, 'transformer.decoder_layer.28': 3, 'transformer.decoder_layer.29': 3,
 'transformer.decoder_layer.30': 4, 'transformer.decoder_layer.31': 4, 'transformer.decoder_layer.32': 4, 'transformer.decoder_layer.33': 4,
 'transformer.decoder_layer.34': 4, 'transformer.decoder_layer.35': 4, 'transformer.decoder_layer.36': 4, 'transformer.decoder_layer.37': 4,
 'transformer.decoder_layer.38': 5, 'transformer.decoder_layer.39': 5, 'transformer.decoder_layer.40': 5, 'transformer.decoder_layer.41': 5,
 'transformer.decoder_layer.42': 5, 'transformer.decoder_layer.43': 5, 'transformer.decoder_layer.44': 5, 'transformer.decoder_layer.45': 5,
 'transformer.decoder_layer.46': 6, 'transformer.decoder_layer.47': 6, 'transformer.decoder_layer.48': 6, 'transformer.decoder_layer.49': 6,
 'transformer.decoder_layer.50': 6, 'transformer.decoder_layer.51': 6, 'transformer.decoder_layer.52': 6, 'transformer.decoder_layer.53': 6,
 'transformer.decoder_layer.54': 7, 'transformer.decoder_layer.55': 7, 'transformer.decoder_layer.56': 7, 'transformer.decoder_layer.57': 7,
 'transformer.decoder_layer.58': 7, 'transformer.decoder_layer.59': 7, 'transformer.decoder_layer.60': 7, 'transformer.decoder_layer.61': 7,
 'transformer.decoder_layer.62': 7, 'transformer.decoder_layer.63': 7,
 'transformer.rms_norm': 7}
```
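The map follows a simple pattern: GPU 0 holds the embeddings, `lm_head`, and only 6 decoder layers; GPUs 1 to 6 take 8 layers each; GPU 7 takes the remaining 10 layers plus the final `rms_norm`. A small helper that rebuilds that pattern, so it can be adapted to a different GPU count, might look like this. The function and its parameters are hypothetical; only the module names are copied from the map, and they have to match the module names of whichever checkpoint is actually loaded.

```python
from typing import Dict


def build_device_map(
    num_layers: int = 64,
    num_gpus: int = 8,
    first_gpu_layers: int = 6,
    layers_per_gpu: int = 8,
) -> Dict[str, int]:
    """Rebuild the hand-written Grok-1 device map from this thread.

    GPU 0 gets the embeddings and lm_head plus `first_gpu_layers` layers,
    each following GPU gets `layers_per_gpu` layers, and the last GPU
    absorbs any leftover layers and the final norm.
    """
    device_map = {"transformer.in_out_embed": 0, "lm_head": 0}
    for i in range(num_layers):
        if i < first_gpu_layers:
            gpu = 0
        else:
            # Fill GPUs 1, 2, ... in blocks; clamp overflow onto the last GPU.
            gpu = min(1 + (i - first_gpu_layers) // layers_per_gpu, num_gpus - 1)
        device_map[f"transformer.decoder_layer.{i}"] = gpu
    device_map["transformer.rms_norm"] = num_gpus - 1
    return device_map
```

The resulting dict can be passed as `device_map=` to `from_pretrained` in place of `"auto"`, for example together with `quantization_config=BitsAndBytesConfig(load_in_4bit=True)` for 4-bit loading; fewer layers on GPU 0 leaves room for the embedding and output layers placed there.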

LagPixelLOL, Mar 21 '24 23:03

LGTM +1. Without running it I can't tell whether there are any logical issues, but the code quality looks good.

Aareon, Mar 22 '24 23:03