
[BUG]: fine-tune with stable diffusion


🐛 Describe the bug

When I run the training example with CIFAR-10, it breaks with an error: RuntimeError: Expected is_sm80 to be true, but got false.

As mentioned in https://github.com/HazyResearch/flash-attention/issues/51, the author said the dim_head must be 128,

but the input dim_head of SD's attention modules is lower than 128.

Environment

RTX 3090, CUDA 11.3, torch 1.12.1, colossalai 0.1.10+torch1.12cu11.3
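For reference, the is_sm80 in the error refers to the GPU's CUDA compute capability (8.0 is A100-class), while a 3090 reports 8.6. A quick way to confirm what the local device reports (illustrative snippet, not part of the example):

import torch

# Print the local GPU's compute capability; an RTX 3090 reports (8, 6), i.e. sm_86,
# whereas the is_sm80 assertion in the traceback expects an 8.0 (A100-class) device.
major, minor = torch.cuda.get_device_capability(0)
print(torch.cuda.get_device_name(0), f"sm_{major}{minor}")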

Screenshots

[screenshot: error traceback]

By the way, the use_ema flag is false in the training config; when I turned it on, it threw another error.

English is not my native language, so please excuse my grammar mistakes.

ray0809 avatar Nov 14 '22 06:11 ray0809

Flash attention is not a stable feature in ColossalAI. Flash attention is an approximation algorithm for attention; the dim size must be a power of 2.
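For illustration only (not ColossalAI's actual check; the supported sizes here are an assumption), the constraint can be thought of as a guard that only takes the flash path when dim_head is a supported size and falls back to plain attention otherwise:

import torch

# Assumed supported head sizes; flash-attention kernels are typically compiled for a
# small set of power-of-two head dims.
SUPPORTED_HEAD_DIMS = {16, 32, 64, 128}

def attention(q, k, v, dim_head, flash_attention_fn=None):
    # q, k, v: (batch * heads, seq, dim_head)
    if flash_attention_fn is not None and dim_head in SUPPORTED_HEAD_DIMS:
        return flash_attention_fn(q, k, v)
    # fallback: vanilla scaled dot-product attention
    scores = torch.einsum('b i d, b j d -> b i j', q, k) * dim_head ** -0.5
    return torch.einsum('b i j, b j d -> b i d', scores.softmax(dim=-1), v)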

Fazziekey avatar Nov 15 '22 02:11 Fazziekey

@Fazziekey thanks for your response. In ColossalAI's stable diffusion example, the flash_attn flag is disabled, but the README says that using flash attention can save a lot of GPU memory, so I turned it on.

After padding the dim size to a power of 2, like this:

def _pad(self, q, k, v, dim_head):
    # Pad so the effective head size becomes the next supported size (32/64/128).
    pad = 0
    if dim_head <= 32:
        pad = 32 - dim_head
    elif dim_head <= 64:
        pad = 64 - dim_head
    elif dim_head <= 128:
        pad = 128 - dim_head
    else:
        raise ValueError(f'Head size {dim_head} too large for Flash Attention')
    if pad:
        # q/k/v are assumed to be (batch, seq, heads * dim_head); append heads * pad zeros to the last dim
        q = torch.nn.functional.pad(q, (0, self.heads * pad), value=0)
        k = torch.nn.functional.pad(k, (0, self.heads * pad), value=0)
        v = torch.nn.functional.pad(v, (0, self.heads * pad), value=0)

    return q, k, v, pad
# inside the attention forward, after q, k, v are computed
dim_head = int(dim_head)
q, k, v, pad = self._pad(q, k, v, dim_head)
print(">>> q, k, v:", q.shape, k.shape, v.shape)  # debug: confirm the padded shapes
if q.shape[1] == k.shape[1]:
    out = self._flash_attention_qkv(q, k, v)
else:
    out = self._flash_attention_q_kv(q, k, v)
if pad:
    # strip the padding from the output's last dim
    out = out[..., :self.heads * dim_head]

But the same error still occurs. [screenshot: error traceback]
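For comparison, a minimal sketch of padding per head instead (assuming q/k/v are still laid out as (batch, seq, heads * dim_head) at that point, which the shape printout above would confirm; the helper name is hypothetical): split the heads first, then pad, so the zeros end up at the tail of every head rather than only after the last head.

import torch
import torch.nn.functional as F

def pad_per_head(x, heads, dim_head, target):
    # x: (batch, seq, heads * dim_head)  -- assumed layout
    assert target >= dim_head
    b, n, _ = x.shape
    x = x.reshape(b, n, heads, dim_head)            # split heads
    x = F.pad(x, (0, target - dim_head), value=0)   # pad each head's feature dim up to target
    return x.reshape(b, n, heads * target)          # merge heads back

With this layout the unpadding after attention also has to happen per head (slice each head back to dim_head) rather than with a single out[..., :self.heads*dim_head] slice.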

ray0809 avatar Nov 17 '22 03:11 ray0809

We fixed use_ema in https://github.com/hpcaitech/ColossalAI/pull/1986/files

Fazziekey avatar Nov 18 '22 09:11 Fazziekey

ok, Thank you~

ray0809 avatar Nov 22 '22 06:11 ray0809