onediff
[Bug] Flux compilation does not support num_images_per_prompt greater than 1
Your current environment information
Current env:
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A
OneFlow version: none
Nexfort version: 0.1.dev317+torch251cu124
OneDiff version: 1.0.1.dev187+gc8f513ec
OneDiffX version: none
OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.25.0
Libc version: glibc-2.35
Python version: 3.10.14 (main, Mar 21 2024, 16:24:04) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-94-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A40
Nvidia driver version: 535.154.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 6
On-line CPU(s) list: 0-5
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Gold 6326 CPU @ 2.90GHz
CPU family: 6
Model: 106
Thread(s) per core: 2
Core(s) per socket: 3
Socket(s): 1
Stepping: 6
BogoMIPS: 5799.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd arat avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid fsrm md_clear arch_capabilities
Virtualization: VT-x
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 144 KiB (3 instances)
L1i cache: 96 KiB (3 instances)
L2 cache: 3.8 MiB (3 instances)
L3 cache: 24 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-5
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled
Versions of relevant libraries:
[pip3] diffusers==0.32.2
[pip3] nexfort==0.1.dev317+torch251cu124
[pip3] numpy==1.26.0
[pip3] onnx==1.16.0
[pip3] onnx-graphsurgeon==0.5.2
[pip3] pytorch-lightning==2.2.4
[pip3] torch==2.5.1
[pip3] torch-optimi==0.2.1
[pip3] torchaudio==2.3.0+cu121
[pip3] torchmetrics==1.4.0.post0
[pip3] torchsde==0.2.6
[pip3] torchvision==0.20.1
[pip3] transformers==4.39.3
[pip3] triton==3.1.0
[conda] nexfort 0.1.dev317+torch251cu124 pypi_0 pypi
[conda] numpy 1.26.0 pypi_0 pypi
[conda] pytorch-lightning 2.2.4 pypi_0 pypi
[conda] torch 2.5.1 pypi_0 pypi
[conda] torch-optimi 0.2.1 pypi_0 pypi
[conda] torchaudio 2.3.0+cu121 pypi_0 pypi
[conda] torchmetrics 1.4.0.post0 pypi_0 pypi
[conda] torchsde 0.2.6 pypi_0 pypi
[conda] torchvision 0.20.1 pypi_0 pypi
[conda] triton 3.1.0 pypi_0 pypi
🐛 Describe the bug
Compiling any Flux pipeline and running it with num_images_per_prompt > 1 results in the following error:
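A minimal sketch of the kind of script that triggers it (the input paths and prompt are placeholders, and this assumes the `onediff.infer_compiler.compile` entry point with the nexfort backend; any Flux pipeline compiled this way behaves the same):

```python
import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image
from onediff.infer_compiler import compile

pipe = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Compile only the transformer with the nexfort backend.
pipe.transformer = compile(pipe.transformer, backend="nexfort")

# Placeholder inputs; any image/mask pair reproduces the failure.
image = load_image("cup.png")
mask = load_image("cup_mask.png")

# num_images_per_prompt=1 works; any value > 1 hits the RuntimeError below.
images = pipe(
    prompt="a white paper cup",
    image=image,
    mask_image=mask,
    num_images_per_prompt=2,
).images
```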
File ~/miniconda3/envs/testenv/lib/python3.10/site-packages/diffusers/pipelines/flux/pipeline_flux_fill.py:916, in FluxFillPipeline.__call__(self, prompt, prompt_2, image, mask_image, masked_image_latents, height, width, num_inference_steps, sigmas, guidance_scale, num_images_per_prompt, generator, latents, prompt_embeds, pooled_prompt_embeds, output_type, return_dict, joint_attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length)
913 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
914 timestep = t.expand(latents.shape[0]).to(latents.dtype)
--> 916 noise_pred = self.transformer(
917 hidden_states=torch.cat((latents, masked_image_latents), dim=2),
918 timestep=timestep / 1000,
919 guidance=guidance,
920 pooled_projections=pooled_prompt_embeds,
921 encoder_hidden_states=prompt_embeds,
922 txt_ids=text_ids,
923 img_ids=latent_image_ids,
924 joint_attention_kwargs=self.joint_attention_kwargs,
925 return_dict=False,
926 )[0]
928 # compute the previous noisy sample x_t -> x_t-1
929 latents_dtype = latents.dtype
File ~/miniconda3/envs/testenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/testenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/miniconda3/envs/testenv/lib/python3.10/site-packages/diffusers/models/transformers/transformer_flux.py:565, in FluxTransformer2DModel.forward(self, hidden_states, encoder_hidden_states, pooled_projections, timestep, img_ids, txt_ids, guidance, joint_attention_kwargs, controlnet_block_samples, controlnet_single_block_samples, return_dict, controlnet_blocks_repeat)
556 hidden_states = torch.utils.checkpoint.checkpoint(
557 create_custom_forward(block),
558 hidden_states,
(...)
561 **ckpt_kwargs,
562 )
564 else:
--> 565 hidden_states = block(
566 hidden_states=hidden_states,
567 temb=temb,
568 image_rotary_emb=image_rotary_emb,
569 joint_attention_kwargs=joint_attention_kwargs,
570 )
572 # controlnet residual
573 if controlnet_single_block_samples is not None:
File ~/miniconda3/envs/testenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/testenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File src/nexfort/nn/modules.py:260, in nexfort.nn.modules.FluxSingleTransformerBlock.forward()
File ~/miniconda3/envs/testenv/lib/python3.10/site-packages/nexfort/kernels/utils.py:30, in with_cuda_ctx.<locals>.decorated(*args, **kwargs)
28 assert isinstance(args[0], (torch.Tensor, torch.nn.Parameter))
29 with torch.cuda.device(args[0].device):
---> 30 return fn(*args, **kwargs)
File src/nexfort/nn/functional.py:79, in nexfort.nn.functional.fuse_linear_mul_residual()
RuntimeError: shape '[3072]' is invalid for input of size 6144
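For what it's worth, 6144 is exactly 2 × 3072 (Flux's inner dimension), so the fused nexfort kernel appears to hard-code a batch-size-1 view. The underlying PyTorch error can be reproduced in isolation (a sketch under that assumption; `gate` is a hypothetical name for the per-sample tensor being flattened):

```python
import torch

batch, inner_dim = 2, 3072            # num_images_per_prompt=2, Flux inner dim
gate = torch.randn(batch, inner_dim)  # stand-in for the kernel's per-sample input

# view(inner_dim) only works when batch == 1; with batch == 2 the tensor has
# 6144 elements, producing the exact error from the traceback above.
gate.view(inner_dim)  # RuntimeError: shape '[3072]' is invalid for input of size 6144
```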