grok-1
PyTorch Hugging Face Transformers implementation by keyfan.
Someone made a full Hugging Face implementation that is way better than mine, so use it instead!
https://huggingface.co/keyfan/grok-1-hf
Previous comment
Very rough implementation, may be broken
Part of the code is copied from Hugging Face's Mixtral implementation. The attention in Grok-1 is not standard, so I don't think flash attention can be used. It's barely tested, but the input "1 + 1 = " outputs "2", so I guess it's kind of working. If anyone can build on this to make a proper Hugging Face Transformers port, it would be greatly appreciated! Weights are uploaded to https://huggingface.co/v2ray/grok-1-pytorch.
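The comment above doesn't say what makes the attention non-standard. One likely candidate, assuming xAI's reference JAX code soft-caps the attention logits as 30·tanh(x/30) before the softmax, is an extra elementwise step that a stock fused attention kernel would not apply. The sketch below is a pure-Python, single-query illustration of that idea; `softcap` and `attend` are hypothetical helper names, not part of this repo.

```python
import math


def softcap(x, cap=30.0):
    # Smoothly squash x into (-cap, cap); close to identity when |x| << cap.
    return cap * math.tanh(x / cap)


def attend(query, keys, values, scale, cap=None):
    """Single-query scaled dot-product attention in pure Python.

    If cap is given, the scaled logits are soft-capped before the softmax
    (assumed Grok-1 behaviour); with cap=None this is plain attention.
    """
    logits = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    if cap is not None:
        logits = [softcap(l, cap) for l in logits]
    m = max(logits)  # subtract the max for numerical stability
    weights = [math.exp(l - m) for l in logits]
    total = sum(weights)
    weights = [w / total for w in weights]
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]
```

With small logits the cap barely matters, but with large ones (e.g. 60 vs 40) it flattens the distribution noticeably, which is why the capping has to happen inside the attention computation rather than after it.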
>>> inp = torch.tensor([[2]+t.encode("2 + 3 =")+[130089]])
>>> with torch.no_grad():
...     r = m(inp)
...
>>> r = torch.nn.functional.softmax(r[0][:, -1, :], dim=-1)
>>> a, b = torch.topk(r, 1, dim=-1)
>>> b
tensor([[19]])
>>> t.decode(19)
'5'
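The session above picks a single next token. Repeating that top-1 pick gives greedy decoding; here is a minimal pure-Python sketch, assuming a hypothetical `next_token_logits` callable that maps token ids to next-position logits (in the session that would be `m(inp)[0][:, -1, :]`). `greedy_generate`, `softmax` and `argmax` are illustrative helpers, not part of this repo.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]


def argmax(xs):
    return max(range(len(xs)), key=xs.__getitem__)


def greedy_generate(next_token_logits, prompt_ids, max_new_tokens, eos_id=None):
    """Append the top-1 token repeatedly, stopping early at eos_id."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = next_token_logits(ids)
        # Softmax is monotonic, so argmax(softmax(l)) == argmax(l); the
        # softmax in the session is only needed if you want probabilities.
        tok = argmax(logits)
        ids.append(tok)
        if eos_id is not None and tok == eos_id:
            break
    return ids
```

Because softmax preserves ordering, `torch.topk(r, 1)` on the raw logits would return the same index as in the session; the softmax only matters if you want to inspect the probability itself.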
I tried to run this on 8x A10 in 4-bit and was getting this error. Unless I miscalculated, it should fit on 8x A10 (24 GB), right?
ValueError:
Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit
the quantized model. If you want to dispatch the model on the CPU or the disk while keeping
these modules in 32-bit, you need to set load_in_8bit_fp32_cpu_offload=True and pass a custom
device_map to from_pretrained. Check
https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu
for more details.
It did load on CPU, but I wasn't able to run inference. I've uploaded the 4-bit version here anyway in case anyone can test it: https://huggingface.co/eastwind/grok-1-hf-4bit
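A back-of-the-envelope check of the "should fit" question, assuming Grok-1's announced 314B parameter count and counting weights only (no quantization constants, unquantized modules, CUDA context or activations):

```python
def weight_gb(n_params, bits_per_param):
    """Approximate weight storage in GB (1e9 bytes), ignoring all overhead."""
    return n_params * bits_per_param / 8 / 1e9


N_PARAMS = 314e9          # Grok-1 parameter count (xAI's announced figure)
GPUS, GB_PER_GPU = 8, 24  # 8x A10, 24 GB each

for bits in (16, 8, 4):
    need = weight_gb(N_PARAMS, bits)
    total = GPUS * GB_PER_GPU
    print(f"{bits:>2}-bit: {need:6.1f} GB of weights vs {total} GB total VRAM")
```

At 4 bits the weights alone come to about 157 GB against 192 GB of total VRAM, leaving roughly 35 GB (about 4.4 GB per GPU) for everything else. That margin is thin enough that a slightly uneven automatic placement could plausibly push some modules onto the CPU.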
@nivibilla I just tried loading it on 8x A100 80GB and it was using 20 GB of VRAM on each GPU. In your case, device_map="auto" may have misestimated memory usage and placed some of the modules on the CPU. Could you try again with this device map?
{'transformer.in_out_embed': 0, 'lm_head': 0,
 'transformer.decoder_layer.0': 0, 'transformer.decoder_layer.1': 0, 'transformer.decoder_layer.2': 0, 'transformer.decoder_layer.3': 0, 'transformer.decoder_layer.4': 0, 'transformer.decoder_layer.5': 0,
 'transformer.decoder_layer.6': 1, 'transformer.decoder_layer.7': 1, 'transformer.decoder_layer.8': 1, 'transformer.decoder_layer.9': 1, 'transformer.decoder_layer.10': 1, 'transformer.decoder_layer.11': 1, 'transformer.decoder_layer.12': 1, 'transformer.decoder_layer.13': 1,
 'transformer.decoder_layer.14': 2, 'transformer.decoder_layer.15': 2, 'transformer.decoder_layer.16': 2, 'transformer.decoder_layer.17': 2, 'transformer.decoder_layer.18': 2, 'transformer.decoder_layer.19': 2, 'transformer.decoder_layer.20': 2, 'transformer.decoder_layer.21': 2,
 'transformer.decoder_layer.22': 3, 'transformer.decoder_layer.23': 3, 'transformer.decoder_layer.24': 3, 'transformer.decoder_layer.25': 3, 'transformer.decoder_layer.26': 3, 'transformer.decoder_layer.27': 3, 'transformer.decoder_layer.28': 3, 'transformer.decoder_layer.29': 3,
 'transformer.decoder_layer.30': 4, 'transformer.decoder_layer.31': 4, 'transformer.decoder_layer.32': 4, 'transformer.decoder_layer.33': 4, 'transformer.decoder_layer.34': 4, 'transformer.decoder_layer.35': 4, 'transformer.decoder_layer.36': 4, 'transformer.decoder_layer.37': 4,
 'transformer.decoder_layer.38': 5, 'transformer.decoder_layer.39': 5, 'transformer.decoder_layer.40': 5, 'transformer.decoder_layer.41': 5, 'transformer.decoder_layer.42': 5, 'transformer.decoder_layer.43': 5, 'transformer.decoder_layer.44': 5, 'transformer.decoder_layer.45': 5,
 'transformer.decoder_layer.46': 6, 'transformer.decoder_layer.47': 6, 'transformer.decoder_layer.48': 6, 'transformer.decoder_layer.49': 6, 'transformer.decoder_layer.50': 6, 'transformer.decoder_layer.51': 6, 'transformer.decoder_layer.52': 6, 'transformer.decoder_layer.53': 6,
 'transformer.decoder_layer.54': 7, 'transformer.decoder_layer.55': 7, 'transformer.decoder_layer.56': 7, 'transformer.decoder_layer.57': 7, 'transformer.decoder_layer.58': 7, 'transformer.decoder_layer.59': 7, 'transformer.decoder_layer.60': 7, 'transformer.decoder_layer.61': 7, 'transformer.decoder_layer.62': 7, 'transformer.decoder_layer.63': 7,
 'transformer.rms_norm': 7}
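Writing a 67-entry map by hand is error-prone. A minimal sketch that generates the same layout from per-GPU layer counts, assuming the module names used in the map above (`build_device_map` is a hypothetical helper):

```python
def build_device_map(layers_per_gpu):
    """Build a manual device map in the style of the one above.

    layers_per_gpu[i] is how many decoder layers go on GPU i. The embedding
    and lm_head go on GPU 0 and the final norm on the last GPU, mirroring the
    hand-written map; module names are copied from that map.
    """
    last = len(layers_per_gpu) - 1
    device_map = {"transformer.in_out_embed": 0, "lm_head": 0}
    layer = 0
    for gpu, count in enumerate(layers_per_gpu):
        for _ in range(count):
            device_map[f"transformer.decoder_layer.{layer}"] = gpu
            layer += 1
    device_map["transformer.rms_norm"] = last
    return device_map


# 6 layers on GPU 0 (which also holds the embedding and lm_head),
# 8 on GPUs 1-6, and 10 on GPU 7 (which also holds the final norm).
device_map = build_device_map([6, 8, 8, 8, 8, 8, 8, 10])
```

The resulting dict can be passed as `device_map=device_map` to `from_pretrained`. GPU 0 gets fewer layers because it also carries the embedding and lm_head. Alternatively, keeping `device_map="auto"` but passing a `max_memory` dict with a per-GPU cap below 24 GB may also steer the automatic placement away from the CPU.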
LGTM +1. Without running it I can't tell whether there are any logical issues, but the code quality looks good.