Inference error
Hi @Huage001, I converted inference_t2i.ipynb to a .py file (the code is exactly the same) and tried to test the acceleration on FLUX dev, but I hit the error below during the pipe inference stage. I'm not sure where the issue is — maybe the compile step in flex_attn, or create_block_mask?
```
Traceback (most recent call last):
  File "/CLEAR/inference_t2i.py", line 174, in <module>
    check_sparse_speed()
  File "/CLEAR/inference_t2i.py", line 126, in check_sparse_speed
    image = pipe(
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/diffusers/pipelines/flux/pipeline_flux.py", line 889, in __call__
    noise_pred = self.transformer(
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/diffusers/models/transformers/transformer_flux.py", line 524, in forward
    encoder_hidden_states, hidden_states = block(
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/diffusers/models/transformers/transformer_flux.py", line 180, in forward
    attention_outputs = self.attn(
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/diffusers/models/attention_processor.py", line 588, in forward
    return self.processor(
  File "/CLEAR/attention_processor.py", line 228, in __call__
    hidden_states = self.flex_attn(query, torch.cat([key, key_downsample], dim=2), torch.cat([value, value_downsample], dim=2), scale=attention_scale)
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/external_utils.py", line 38, in inner
    @functools.wraps(fn)
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
    return fn(*args, **kwargs)
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1100, in forward
    return compiled_fn(full_args)
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 321, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 124, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 667, in inner_fn
    outs = compiled_fn(args)
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 488, in wrapper
    return compiled_fn(runtime_args)
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/codecache.py", line 1478, in __call__
    return self.current_callable(inputs)
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/utils.py", line 1977, in run
    return model(new_inputs)
  File "/tmp/torchinductor_tiger/ax/caxlewgnrr7i2npupywktxn7nuuvirr3l3qw6yxayenhemxthps5.py", line 536, in call
    triton_tem_fused_0.run(arg0_1, arg1_1, arg2_1, buf0, arg3_1, arg4_1, arg5_1, arg6_1, buf1, grid=torch._inductor.kernel.flex_attention.flex_attention_grid(1, 24, 4608, 128, meta0), stream=stream7)
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 879, in run
    return launcher(
  File "<string>", line 13, in launcher
  File "my_path/miniconda3/envs/CLEAR/lib/python3.12/site-packages/triton/backends/nvidia/driver.py", line 365, in __call__
    self.launch(*args, **kwargs)
ValueError: Pointer argument (at 3) cannot be accessed from Triton (cpu tensor?)
```
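The final `ValueError` usually means that at least one tensor handed to the compiled Triton kernel lives on the CPU rather than the GPU. One way to narrow this down is to fail fast before the `self.flex_attn(...)` call in attention_processor.py. The helper below is a hypothetical debugging sketch (not part of CLEAR); it duck-types on the `.device` attribute that `torch.Tensor` exposes:

```python
def assert_all_cuda(**tensors):
    """Raise a clear error if any named tensor is not on a CUDA device.

    Intended as a pre-flight check before calling a compiled
    flex_attention kernel, whose Triton launch fails opaquely
    ('cannot be accessed from Triton (cpu tensor?)') on CPU inputs.
    Works on anything exposing a .device attribute, e.g. torch.Tensor.
    """
    for name, t in tensors.items():
        if t.device.type != "cuda":
            raise RuntimeError(
                f"{name} is on {t.device}; every input to the Triton "
                f"flex_attention kernel must live on the GPU"
            )
```

Calling it as `assert_all_cuda(query=query, key=key, value=value, key_downsample=key_downsample, value_downsample=value_downsample)` right before the failing line would name the offending tensor instead of the opaque pointer-argument index.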
Hi,
It's strange. From the error message, it seems the problem is with the CUDA device. If you are running on a multi-GPU node, you can try adding the following code at the very beginning of the notebook:
```
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=x
```
Please change x to the id of the GPU you want to use, and let me know if the error persists.
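Since the script here is a plain .py file, the `%env` IPython magics above won't work; the equivalent is to set the variables through `os.environ` before torch (or anything else that initializes CUDA) is imported. A minimal sketch, where the GPU id `"0"` is just a placeholder:

```python
import os

# These must be set before torch / diffusers are imported, otherwise the
# CUDA device mapping is already fixed and the variables have no effect.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # placeholder: put your GPU id here

# import torch      # only import CUDA-using libraries after this point
# import diffusers
```

Alternatively, the same effect can be had without touching the script: `CUDA_VISIBLE_DEVICES=0 python inference_t2i.py`.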