ai-toolkit
ai-toolkit copied to clipboard
Is there any point in using xformers with Flux?
First of all, thanks for this great repo! I'm running out of VRAM on a 24 GB GPU while trying to train a rank-128 Flux LoRA. When I install xformers and try to use it, I get the following error:
Traceback (most recent call last):
File "D:\ai-toolkit\run.py", line 90, in <module>
main()
File "D:\ai-toolkit\run.py", line 86, in main
raise e
File "D:\ai-toolkit\run.py", line 78, in main
job.run()
File "D:\ai-toolkit\jobs\ExtensionJob.py", line 22, in run
process.run()
File "D:\ai-toolkit\jobs\process\BaseSDTrainProcess.py", line 1701, in run
loss_dict = self.hook_train_loop(batch)
File "D:\ai-toolkit\extensions_built_in\sd_trainer\SDTrainer.py", line 1483, in hook_train_loop
noise_pred = self.predict_noise(
File "D:\ai-toolkit\extensions_built_in\sd_trainer\SDTrainer.py", line 891, in predict_noise
return self.sd.predict_noise(
File "D:\ai-toolkit\toolkit\stable_diffusion_model.py", line 1650, in predict_noise
noise_pred = self.unet(
File "D:\ai-toolkit\venv\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "D:\ai-toolkit\venv\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "D:\ai-toolkit\venv\lib\site-packages\diffusers\models\transformers\transformer_flux.py", line 400, in forward
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
File "D:\ai-toolkit\venv\lib\site-packages\torch\_compile.py", line 31, in inner
return disable_fn(*args, **kwargs)
File "D:\ai-toolkit\venv\lib\site-packages\torch\_dynamo\eval_frame.py", line 600, in _fn
return fn(*args, **kwargs)
File "D:\ai-toolkit\venv\lib\site-packages\torch\utils\checkpoint.py", line 488, in checkpoint
ret = function(*args, **kwargs)
File "D:\ai-toolkit\venv\lib\site-packages\diffusers\models\transformers\transformer_flux.py", line 395, in custom_forward
return module(*inputs)
File "D:\ai-toolkit\venv\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "D:\ai-toolkit\venv\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "D:\ai-toolkit\venv\lib\site-packages\diffusers\models\transformers\transformer_flux.py", line 201, in forward
attn_output, context_attn_output = self.attn(
ValueError: not enough values to unpack (expected 2, got 1)