QKVFlashAttention unexpected parameters error, running in Google Colab
I tried to generate samples in Colab, and everything works except that I had to change this code in /cm/unet.py, clearing out factory_kwargs. Not sure if this is a bug or whether I did something wrong. This is how I ran it: https://github.com/JonathanFly/consistency_models_colab_notebook/blob/main/Consistency_Models_Make_Samples.ipynb
class QKVFlashAttention(nn.Module):
    def __init__(
        self,
        embed_dim,
        num_heads,
        batch_first=True,
        attention_dropout=0.0,
        causal=False,
        device=None,
        dtype=None,
        **kwargs,
    ) -> None:
        from einops import rearrange
        from flash_attn.flash_attention import FlashAttention

        assert batch_first
        # factory_kwargs = {"device": device, "dtype": dtype}
        factory_kwargs = {}
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.causal = causal
I also hit the same error. I guess this code base was written against an older version of flash_attn.
Logging to /tmp/openai-2023-04-13-14-33-55-278549
creating model and diffusion...
Traceback (most recent call last):
File "/content/consistency_models/scripts/image_sample.py", line 143, in <module>
main()
File "/content/consistency_models/scripts/image_sample.py", line 37, in main
model, diffusion = create_model_and_diffusion(
File "/content/consistency_models/cm/script_util.py", line 76, in create_model_and_diffusion
model = create_model(
File "/content/consistency_models/cm/script_util.py", line 140, in create_model
return UNetModel(
File "/content/consistency_models/cm/unet.py", line 612, in __init__
AttentionBlock(
File "/content/consistency_models/cm/unet.py", line 293, in __init__
self.attention = QKVFlashAttention(channels, self.num_heads)
File "/content/consistency_models/cm/unet.py", line 359, in __init__
self.inner_attn = FlashAttention(
TypeError: __init__() got an unexpected keyword argument 'device'
Related minor issue: in th_evaluator.py, inception-2015-12-05.pt tries to download automatically but fails, and it doesn't seem like you can pass the path on the command line. Also, is it supposed to automatically calculate stats for a reference batch? (I'm probably trying to run the sampling out of order?)
class FIDAndIS:
    def __init__(
        self,
        softmax_batch_size=512,
        clip_score_batch_size=512,
        path="https://openaipublic.blob.core.windows.net/consistency/inception/inception-2015-12-05.pt",
    ):
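A possible workaround, as a sketch (the import path and the use of a local file for path are my assumptions, based only on the signature above): download the checkpoint manually once, then construct the evaluator with the local path instead of the URL.

import os
import urllib.request

# Assumed import path; adjust to wherever th_evaluator.py lives in your checkout.
from evaluations.th_evaluator import FIDAndIS

CKPT_URL = "https://openaipublic.blob.core.windows.net/consistency/inception/inception-2015-12-05.pt"
CKPT_LOCAL = "/content/inception-2015-12-05.pt"

if not os.path.exists(CKPT_LOCAL):
    urllib.request.urlretrieve(CKPT_URL, CKPT_LOCAL)

fid_is = FIDAndIS(path=CKPT_LOCAL)  # skips the flaky automatic download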
v1.0.2 has no device parameter: https://github.com/HazyResearch/flash-attention/blob/v1.0.2/flash_attn/flash_attention.py#L21
But v0.2.8 does have a device parameter: https://github.com/HazyResearch/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py#L21
Running pip install flash-attn==0.2.8 solved it for me.
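If you'd rather not pin an old version, a version-agnostic sketch (my own workaround, not the repo's code) is to filter the factory kwargs against whatever the installed FlashAttention constructor actually accepts:

import inspect
from flash_attn.flash_attention import FlashAttention

def make_flash_attention(attention_dropout=0.0, **factory_kwargs):
    # Keep only the kwargs the installed constructor knows about, so the call
    # works with both v0.2.8 (which takes device/dtype) and v1.x (which doesn't).
    accepted = inspect.signature(FlashAttention.__init__).parameters
    kept = {k: v for k, v in factory_kwargs.items() if k in accepted}
    return FlashAttention(attention_dropout=attention_dropout, **kept)

This matches the observation above that simply clearing factory_kwargs works: the device/dtype arguments are evidently not essential here.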
pip install flash-attn==0.2.8 solved it for me too. After this procedure, I started training the model with these parameters, and then an error came up. Does anyone know what it means? I'm a PyTorch rookie.
(py3_8_16) bld@bld:~/consistency_models/scripts$ mpiexec -n 1 python cm_train.py --training_mode consistency_training --target_ema_mode adaptive --start_ema 0.95 --scale_mode progressive --start_scales 2 --end_scales 150 --total_training_steps 100000 --loss_norm lpips --lr_anneal_steps 0 --teacher_model_path /home/bld/pre_train_model/edm_bedroom256_ema.pt --attention_resolutions 32,16,8 --class_cond False --use_scale_shift_norm False --dropout 0.0 --teacher_dropout 0.1 --ema_rate 0.9999,0.99994,0.9999432189950708 --global_batch_size 1 --image_size 256 --lr 0.00005 --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --schedule_sampler uniform --use_fp16 True --weight_decay 0.0 --weight_schedule uniform --data_dir /home/bld/lsun/lsun_train_output_dir
Logging to /tmp/openai-2023-04-19-10-51-33-746807
creating model and diffusion...
creating data loader...
loading the teacher model from /home/bld/pre_train_model/edm_bedroom256_ema.pt
creating the target model
training...
/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
warnings.warn(
/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.
warnings.warn(msg)
Traceback (most recent call last):
File "cm_train.py", line 171, in <module>
main()
File "cm_train.py", line 121, in main
CMTrainLoop(
File "/home/bld/consistency_models/cm/train_util.py", line 367, in run_loop
self.run_step(batch, cond)
File "/home/bld/consistency_models/cm/train_util.py", line 389, in run_step
self.forward_backward(batch, cond)
File "/home/bld/consistency_models/cm/train_util.py", line 501, in forward_backward
losses = compute_losses()
File "/home/bld/consistency_models/cm/karras_diffusion.py", line 191, in consistency_losses
distiller = denoise_fn(x_t, t)
File "/home/bld/consistency_models/cm/karras_diffusion.py", line 125, in denoise_fn
return self.denoise(model, x, t, **model_kwargs)[1]
File "/home/bld/consistency_models/cm/karras_diffusion.py", line 347, in denoise
model_output = model(c_in * x_t, rescaled_t, **model_kwargs)
File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
output = self._run_ddp_forward(*inputs, **kwargs)
File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
return module_to_run(*inputs[0], **kwargs[0]) # type: ignore[index]
File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/bld/consistency_models/cm/unet.py", line 765, in forward
h = module(h, emb)
File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/bld/consistency_models/cm/unet.py", line 77, in forward
x = layer(x)
File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/bld/consistency_models/cm/unet.py", line 308, in forward
return checkpoint(
File "/home/bld/consistency_models/cm/nn.py", line 155, in checkpoint
return func(*inputs)
File "/home/bld/consistency_models/cm/unet.py", line 325, in _forward
h = checkpoint(self.attention, (qkv,), (), self.use_attention_checkpoint)
File "/home/bld/consistency_models/cm/nn.py", line 155, in checkpoint
return func(*inputs)
File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/bld/consistency_models/cm/unet.py", line 368, in forward
qkv, _ = self.inner_attn(
File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/flash_attn/flash_attention.py", line 47, in forward
output = flash_attn_unpadded_qkvpacked_func(
File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/flash_attn/flash_attn_interface.py", line 266, in flash_attn_unpadded_qkvpacked_func
return FlashAttnQKVPackedFunc.apply(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale,
File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/flash_attn/flash_attn_interface.py", line 58, in forward
out, softmax_lse, S_dmask = _flash_attn_forward(
File "/home/bld/anaconda3/envs/py3_8_16/lib/python3.8/site-packages/flash_attn/flash_attn_interface.py", line 21, in _flash_attn_forward
softmax_lse, *rest = flash_attn_cuda.fwd(
RuntimeError: Expected q.stride(-1) == 1 to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
--------------------------------------------------------------------------
Primary job terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.
--------------------------------------------------------------------------
--------------------------------------------------------------------------
mpiexec detected that one or more processes exited with non-zero status, thus causing
the job to be terminated. The first process to do so was:
Process name: [[48068,1],0]
Exit code: 1
I got the same problem as well
Have you solved this problem? I don't even know what the error message means.
Solution: make the following change in /content/consistency_models/cm/unet.py, line 359, in __init__:
- self.inner_attn = FlashAttention(
- attention_dropout=attention_dropout, **factory_kwargs
- )
+ self.inner_attn = FlashAttention(attention_dropout=attention_dropout)
I tried that, but it didn't solve the problem. Were there any other changes you made?
Since I'm running on a V100, I also had to disable flash attention (apparently it only works on A100):
index 3fe5184..d9f7c2f 100644
--- a/cm/unet.py
+++ b/cm/unet.py
@@ -270,7 +270,7 @@ class AttentionBlock(nn.Module):
num_heads=1,
num_head_channels=-1,
use_checkpoint=False,
- attention_type="flash",
+ attention_type="default", #"flash", # disable flash-attention by default in order to run on V100
encoder_channels=None,
dims=2,
channels_last=False,
Still doesn't work for me. This is what I get for CD on ImageNet 64, and I get a similar result with EDM. I cannot obtain images of similar quality to those in the paper.
@boxwayne @aarontan-git @asanakoy For the stride issue, I think it's a rearrange problem caused by the flash_attn version:

qkv = self.rearrange(
    qkv, "b (three h d) s -> b s three h d", three=3, h=self.num_heads
)
# print(qkv.shape, qkv.stride())
qkv, _ = self.inner_attn(
    qkv.contiguous(),
    key_padding_mask=key_padding_mask,
    need_weights=need_weights,
    causal=self.causal,
)

The print result is torch.Size([1, 256, 3, 6, 64]) with strides (256, 1, 98304, 16384, 256), which means the tensor is no longer contiguous after rearranging. (The old version might not have required this, while the new version requires the input to be contiguous.) So I simply add a contiguous operation before calling inner_attn:

qkv = qkv.contiguous()

Let me know if that solves the issue. I tested it on my side and it works.
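To see why the rearrange output trips flash-attn's stride check, here is a minimal standalone reproduction (my own sketch; it needs only torch and einops, and the shapes match the print above):

import torch
from einops import rearrange

# Same layout as in the traceback: batch=1, channels=3*6*64=1152, seq_len=256.
qkv = torch.randn(1, 3 * 6 * 64, 256)

out = rearrange(qkv, "b (three h d) s -> b s three h d", three=3, h=6)
print(out.shape, out.stride())   # last stride is 256, not 1
print(out.is_contiguous())       # False -> flash_attn_cuda.fwd rejects it

out = out.contiguous()           # copies the data into a dense layout
print(out.stride()[-1], out.is_contiguous())  # 1 True

This rearrange is just a reshape plus a permute, so it returns a view with permuted strides instead of moving any data; .contiguous() materializes the copy that the CUDA kernel needs.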
The flash-attn version I installed is 1.0.2, and with this fix there is no problem.
I tried your fix, and got the following warning message when trying to run ImageNet consistency training:

Grad strides do not match bucket view strides. This may indicate grad was not created according to the gradient layout contract, or that the param's strides changed since DDP was constructed. This is not an error, but may impair performance.
grad.sizes() = [384, 384, 1, 1], strides() = [384, 1, 384, 384]
bucket_view.sizes() = [384, 384, 1, 1], strides() = [384, 1, 1, 1] (Triggered internally at ../torch/csrc/distributed/c10d/reducer.cpp:325.)
@aarontan-git You could just leave the warning there if you don't care about it. If you want to fix it, check the code to see which part changes the gradient strides, and adjust the tensor storage strides there to avoid the warning.
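If you do want to track it down, a small diagnostic sketch like this (my own suggestion; model is whatever module you wrapped in DDP) can list the offending parameters after one backward pass:

def report_grad_stride_mismatches(model):
    # List parameters whose gradient strides differ from the parameter's own
    # strides -- the condition behind the "bucket view strides" warning.
    for name, p in model.named_parameters():
        if p.grad is not None and p.grad.stride() != p.stride():
            print(name, "param:", p.stride(), "grad:", p.grad.stride())

The sizes [384, 384, 1, 1] in the warning are consistent with 1x1 conv weights, which would point at the projection layers around the patched attention path.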
The flash-attn I installed is version 1.0.2, no problem.
When installing version 1.0.2, the following error occurs; how did you solve it?
As Jonathan said at the top, change this line of code in /cm/unet.py, clearing out factory_kwargs:
# factory_kwargs = {"device": device, "dtype": dtype}
factory_kwargs = {}

(See the full snippet in the original post at the top of the thread.)
What GPU are you using? This doesn't seem to have anything to do with the flash-attn version: if I change attention_type="flash" to "default", the code runs, but the result is poor. If I don't change it, the following error message appears. My GPU is a V100.
I used a single A100.