[BUG]: fine-tune with stable diffusion
🐛 Describe the bug
When I run the training example with CIFAR-10, it breaks with the error RuntimeError: Expected is_sm80 to be true, but got false.
As mentioned in https://github.com/HazyResearch/flash-attention/issues/51, the author says dim_head must be 128, but the dim_head of SD's attention modules is lower than 128.
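For context, is_sm80 refers to the GPU's compute capability: sm80 is the A100 generation, while an RTX 3090 reports sm86. A minimal sketch (assuming a CUDA-enabled PyTorch install) to check what the local device reports:

import torch

# Print the compute capability of the current GPU; flash-attention's is_sm80
# assertion only passes on sm80 (A100-class) devices.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"compute capability: sm{major}{minor}")   # e.g. sm86 on an RTX 3090
    print("is_sm80:", (major, minor) == (8, 0))
else:
    print("no CUDA device visible")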
Environment
RTX 3090, CUDA 11.3, torch 1.12.1, colossalai 0.1.10+torch1.12cu11.3
Screenshots
By the way, the use_ema flag is false in the training config; when I turned it on, it threw another error.
English is not my native language, so please excuse my grammar mistakes.
Flash attention is not a stable feature in ColossalAI yet. Flash attention is an optimized attention kernel, and it requires the head dim to be a power of two.
@Fazziekey thanks for your response. In ColossalAI's stable diffusion example the flash_attn flag is disabled, but the README says that using flash attention can save a lot of GPU memory, so I turned it on.
After padding the head dim to a power of two, like this:
def _pad(self, q, k, v, dim_head):
    # pad the head dim up to the next size flash attention supports (32, 64 or 128)
    pad = 0
    if dim_head <= 32:
        pad = 32 - dim_head
    elif dim_head <= 64:
        pad = 64 - dim_head
    elif dim_head <= 128:
        pad = 128 - dim_head
    else:
        raise ValueError(f'Head size {dim_head} too large for Flash Attention')
    if pad:
        # the last dim is assumed to be heads * dim_head; extend it with zeros
        q = torch.nn.functional.pad(q, (0, self.heads * pad), value=0)
        k = torch.nn.functional.pad(k, (0, self.heads * pad), value=0)
        v = torch.nn.functional.pad(v, (0, self.heads * pad), value=0)
    return q, k, v, pad
dim_head = int(dim_head)
q, k, v, pad = self._pad(q, k, v, dim_head)
print(">>> q, k, v:", q.shape, k.shape, v.shape)  # debug: check the padded shapes
if q.shape[1] == k.shape[1]:
    # q and k share the same sequence length (self-attention)
    out = self._flash_attention_qkv(q, k, v)
else:
    # k/v come from a different context sequence (cross-attention)
    out = self._flash_attention_q_kv(q, k, v)
if pad:
    # strip the zero padding so the output matches the original dim_head
    out = out[..., :self.heads * dim_head]
But the same error still occurs.
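A likely explanation, going by the constraint described in the linked flash-attention issue (an assumption, not something confirmed in this thread): on non-sm80 GPUs such as the RTX 3090 (sm86), the flash-attention backward pass only supports head dims up to 64, so padding a head dim that lies between 64 and 128 up to 128 still trips the is_sm80 assert during training. A hypothetical variant of the padding logic that caps the target at 64 on non-sm80 GPUs could look like this (the 64-dim limit for backward is taken from that issue, not from this thread):

import torch

def pick_padded_dim_head(dim_head: int) -> int:
    # Hypothetical helper (not the author's code): pick the padded head dim based on
    # the GPU, assuming non-sm80 devices only support head dims up to 64 for backward.
    is_sm80 = torch.cuda.is_available() and torch.cuda.get_device_capability() == (8, 0)
    max_dim = 128 if is_sm80 else 64
    for target in (32, 64, 128):
        if dim_head <= target:
            if target > max_dim:
                raise ValueError(
                    f"head dim {dim_head} needs padding to {target}, "
                    f"which this GPU does not support for training")
            return target
    raise ValueError(f"Head size {dim_head} too large for Flash Attention")

With this, a head dim of 40 or 64 pads to 64 and should train on an RTX 3090, while anything above 64 would have to fall back to the standard attention path instead of flash attention.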
We fixed the use_ema issue in https://github.com/hpcaitech/ColossalAI/pull/1986/files
OK, thank you~