
[Bug] Flux compilation does not support num_images_per_prompt greater than 1

Open · dqueue opened this issue 1 year ago · 0 comments

Your current environment information

Current env:

PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OneFlow version: none
Nexfort version: 0.1.dev317+torch251cu124
OneDiff version: 1.0.1.dev187+gc8f513ec
OneDiffX version: none

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.25.0
Libc version: glibc-2.35

Python version: 3.10.14 (main, Mar 21 2024, 16:24:04) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-94-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A40
Nvidia driver version: 535.154.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 57 bits virtual
Byte Order:                         Little Endian
CPU(s):                             6
On-line CPU(s) list:                0-5
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) Gold 6326 CPU @ 2.90GHz
CPU family:                         6
Model:                              106
Thread(s) per core:                 2
Core(s) per socket:                 3
Socket(s):                          1
Stepping:                           6
BogoMIPS:                           5799.99
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd arat avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid fsrm md_clear arch_capabilities
Virtualization:                     VT-x
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          144 KiB (3 instances)
L1i cache:                          96 KiB (3 instances)
L2 cache:                           3.8 MiB (3 instances)
L3 cache:                           24 MiB (1 instance)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-5
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] diffusers==0.32.2
[pip3] nexfort==0.1.dev317+torch251cu124
[pip3] numpy==1.26.0
[pip3] onnx==1.16.0
[pip3] onnx-graphsurgeon==0.5.2
[pip3] pytorch-lightning==2.2.4
[pip3] torch==2.5.1
[pip3] torch-optimi==0.2.1
[pip3] torchaudio==2.3.0+cu121
[pip3] torchmetrics==1.4.0.post0
[pip3] torchsde==0.2.6
[pip3] torchvision==0.20.1
[pip3] transformers==4.39.3
[pip3] triton==3.1.0
[conda] nexfort                   0.1.dev317+torch251cu124          pypi_0    pypi
[conda] numpy                     1.26.0                   pypi_0    pypi
[conda] pytorch-lightning         2.2.4                    pypi_0    pypi
[conda] torch                     2.5.1                    pypi_0    pypi
[conda] torch-optimi              0.2.1                    pypi_0    pypi
[conda] torchaudio                2.3.0+cu121              pypi_0    pypi
[conda] torchmetrics              1.4.0.post0              pypi_0    pypi
[conda] torchsde                  0.2.6                    pypi_0    pypi
[conda] torchvision               0.20.1                   pypi_0    pypi
[conda] triton                    3.1.0                    pypi_0    pypi

🐛 Describe the bug

Compiling any Flux pipeline and then running it with num_images_per_prompt > 1 results in the following error:
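A minimal repro sketch, assuming onediff's nexfort backend is applied via onediffx.compile_pipe (per the onediff README); the model id and prompt are placeholders, and this is untested here since it needs a CUDA GPU and the Flux weights:

```python
def main():
    # Imports kept inside the function: running this requires a CUDA
    # machine with diffusers, onediffx, and nexfort installed.
    import torch
    from diffusers import FluxPipeline
    from onediffx import compile_pipe

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",  # placeholder; any Flux pipeline reproduces it
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    pipe = compile_pipe(pipe, backend="nexfort")

    # Works with num_images_per_prompt=1; with 2 the compiled
    # transformer raises the RuntimeError shown in the traceback below.
    images = pipe(
        "a photo of a cat",
        num_inference_steps=28,
        num_images_per_prompt=2,
    ).images
    return images

# Call main() to reproduce; not invoked here.
```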

File ~/miniconda3/envs/testenv/lib/python3.10/site-packages/diffusers/pipelines/flux/pipeline_flux_fill.py:916, in FluxFillPipeline.__call__(self, prompt, prompt_2, image, mask_image, masked_image_latents, height, width, num_inference_steps, sigmas, guidance_scale, num_images_per_prompt, generator, latents, prompt_embeds, pooled_prompt_embeds, output_type, return_dict, joint_attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length)
    913 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    914 timestep = t.expand(latents.shape[0]).to(latents.dtype)
--> 916 noise_pred = self.transformer(
    917     hidden_states=torch.cat((latents, masked_image_latents), dim=2),
    918     timestep=timestep / 1000,
    919     guidance=guidance,
    920     pooled_projections=pooled_prompt_embeds,
    921     encoder_hidden_states=prompt_embeds,
    922     txt_ids=text_ids,
    923     img_ids=latent_image_ids,
    924     joint_attention_kwargs=self.joint_attention_kwargs,
    925     return_dict=False,
    926 )[0]
    928 # compute the previous noisy sample x_t -> x_t-1
    929 latents_dtype = latents.dtype

File ~/miniconda3/envs/testenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/testenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/miniconda3/envs/testenv/lib/python3.10/site-packages/diffusers/models/transformers/transformer_flux.py:565, in FluxTransformer2DModel.forward(self, hidden_states, encoder_hidden_states, pooled_projections, timestep, img_ids, txt_ids, guidance, joint_attention_kwargs, controlnet_block_samples, controlnet_single_block_samples, return_dict, controlnet_blocks_repeat)
    556     hidden_states = torch.utils.checkpoint.checkpoint(
    557         create_custom_forward(block),
    558         hidden_states,
   (...)
    561         **ckpt_kwargs,
    562     )
    564 else:
--> 565     hidden_states = block(
    566         hidden_states=hidden_states,
    567         temb=temb,
    568         image_rotary_emb=image_rotary_emb,
    569         joint_attention_kwargs=joint_attention_kwargs,
    570     )
    572 # controlnet residual
    573 if controlnet_single_block_samples is not None:

File ~/miniconda3/envs/testenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/testenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File src/nexfort/nn/modules.py:260, in nexfort.nn.modules.FluxSingleTransformerBlock.forward()

File ~/miniconda3/envs/testenv/lib/python3.10/site-packages/nexfort/kernels/utils.py:30, in with_cuda_ctx.<locals>.decorated(*args, **kwargs)
     28 assert isinstance(args[0], (torch.Tensor, torch.nn.Parameter))
     29 with torch.cuda.device(args[0].device):
---> 30     return fn(*args, **kwargs)

File src/nexfort/nn/functional.py:79, in nexfort.nn.functional.fuse_linear_mul_residual()

RuntimeError: shape '[3072]' is invalid for input of size 6144
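The numbers are suggestive: 6144 is exactly 2 × 3072, which points at a shape captured for batch size 1 inside nexfort's fused kernel (fuse_linear_mul_residual) that no longer fits once num_images_per_prompt doubles the batch dimension. The kernel internals are an assumption here; this is just a sketch of the shape mismatch using numpy (torch raises RuntimeError for the same situation, numpy raises ValueError):

```python
import numpy as np

# Hypothetical: a target shape baked in while compiling for batch size 1.
captured_shape = (3072,)

batch1 = np.zeros((1, 3072))
reshaped = batch1.reshape(captured_shape)   # fine: 3072 elements fit [3072]

batch2 = np.zeros((2, 3072))                # num_images_per_prompt=2 doubles the batch
try:
    batch2.reshape(captured_shape)          # 6144 elements cannot fit [3072]
    failed = False
except ValueError:
    failed = True                           # mirrors the RuntimeError above
```

If this reading is right, the fix would be for the fused kernel to derive the reshape target from the runtime batch size rather than from the shape seen at compile time.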

dqueue · Feb 08 '25 10:02