PaLM
Training with Hidet compiler
Hello! I was wondering if there is anything extra that needs to be done to get training working with the Hidet compiler.
Out of the box I seem to be running into errors:
import torch
from palm_rlhf_pytorch import PaLM

palm = PaLM(
    num_tokens = 20000,
    dim = 512,
    depth = 12,
    flash_attn = True, # https://arxiv.org/abs/2205.14135
    cross_entropy_ignore_index = 0
).to(torch.bfloat16).cuda()

palm_opt = torch.compile(palm, backend='hidet')

seq = torch.randint(0, 20000, (1, 1024)).cuda()

loss = palm_opt(seq, return_loss = True)
loss.backward()
Some of the errors I faced were around the usage of rearrange here and here.
It also seems like einsum isn't supported. Even after replacing those ops with equivalent alternatives, I'm still running into some reshape errors from Hidet:
AssertionError: , occurred when interpreting reshape with
tensor_reshape(tensor(...), [1023])
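For context, the kind of replacement meant by "equivalent alternatives" can be sketched as below. This is a minimal illustration only: the shapes, variable names, and einops/einsum patterns are assumptions chosen for the example, not the exact call sites in the repo, and it runs on CPU without Hidet.

```python
import torch

# Illustrative sizes only (hypothetical, not taken from the repo).
b, n, h, d = 2, 8, 4, 16
x = torch.randn(b, n, h * d)

# Stands in for: rearrange(x, 'b n (h d) -> b h n d', h = h)
q = x.view(b, n, h, d).permute(0, 2, 1, 3)

# Stands in for: rearrange(q, 'b h n d -> b n (h d)')
# reshape (not view) is used because permute makes the tensor non-contiguous.
merged = q.permute(0, 2, 1, 3).reshape(b, n, h * d)

# Stands in for: torch.einsum('bhid,bhjd->bhij', q, k)
k = torch.randn(b, h, n, d)
sim = torch.matmul(q, k.transpose(-1, -2))

# Stands in for: torch.einsum('i,j->ij', t, inv_freq)  (an outer product)
t = torch.arange(n, dtype=torch.float32)
inv_freq = torch.randn(d // 2)
freqs = t[:, None] * inv_freq[None, :]
```

Each replacement can be checked against the original op in eager mode before compiling, which helps separate errors caused by the rewrite from errors raised by the backend.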
I can post additional info as needed, but I'm wondering whether you ran into the same errors or whether I'm doing something incorrectly.
Thanks!