Get Flux working on Apple Silicon
Should fix #1103
We would love to see Flux work on Apple Silicon Macs.
Yup, waiting impatiently to try the quantized versions, or maybe NF4.
Also, shouldn't it be possible to set it to bfloat16 too, or offer it as an option? If I remember right, it is supported on MPS starting with PyTorch 2.3.0. I did try it once in Invoke.
OK, apparently bfloat16 only works on M2 or newer; still would be nice to have it ;)
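As a sketch of how such an option could be gated (all names here are hypothetical; the "M2 or newer, PyTorch 2.3+" constraints are taken from the reports in this thread, not from an official support matrix):

```python
# Hypothetical sketch of dtype selection for the MPS backend, based on the
# constraints mentioned in this thread: bfloat16 on MPS reportedly needs
# PyTorch >= 2.3 and an M2 or newer chip. Not Forge's actual API.

def select_mps_dtype(torch_version: tuple, apple_chip_generation: int) -> str:
    """Pick a computation dtype string for the MPS backend."""
    if torch_version >= (2, 3) and apple_chip_generation >= 2:
        return "bfloat16"   # reportedly works on M2+ with PyTorch 2.3+
    return "float16"        # safe fallback for M1 or older PyTorch

print(select_mps_dtype((2, 4), 3))  # → bfloat16 (M3 on PyTorch 2.4)
print(select_mps_dtype((2, 1), 3))  # → float16 (PyTorch too old)
```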
Tried using bfloat16 on my M3, but got the following error:

```
RuntimeError: "arange_mps" not implemented for 'BFloat16'
```
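A common workaround for missing-kernel errors like this (not necessarily the fix this PR takes) is to build the tensor in a dtype the op does support and cast afterwards; the pattern is easy to check on CPU:

```python
import torch

# "arange_mps" is reportedly not implemented for bfloat16, so instead of
# calling torch.arange(..., dtype=torch.bfloat16) directly, build the range
# in float32 (widely supported) and cast the result. Demonstrated on CPU;
# on MPS the same pattern sidesteps the missing kernel.
dim = 8
scale = torch.arange(0, dim, 2, dtype=torch.float32) / dim
scale = scale.to(torch.bfloat16)

print(scale.dtype)     # torch.bfloat16
print(scale.tolist())  # [0.0, 0.25, 0.5, 0.75]
```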
@DenOfEquity is there anything I should do to get this approved?
I guess none of the active collaborators/maintainers can actually test what's going on with MPS. There's also some confusion in the linked issue about whether it works or not, or only with some models. But it doesn't/can't break anything, so if it helps at least sometimes, I'm calling it progress. I'm also curious whether it couldn't just be float32 all the time; that works for me, with 100% identical results (sample size of 1).
Models I've tested: Schnell, Dev, and GGUF all work OK with this "fix". Arguably the difference is that GGUF is the same as Dev but with less VRAM usage, while Schnell is just different: it resolves with fewer steps but looks like another thing.
NF4 won't work for obvious reasons (bitsandbytes hasn't been ported to Mac).
For me it seems none of the FP8 checkpoints work; I get: `Trying to convert Float8_e4m3fn to the MPS backend but it does not have support for that dtype.`
Even if I select the FP16 T5.
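A cheap way to fail early on such checkpoints would be to probe dtype support before moving weights to the device. A minimal sketch (the unsupported-dtype set below is inferred from the errors reported in this thread, not an official list, and the helper name is hypothetical):

```python
# Hypothetical pre-flight check: reject dtypes the MPS backend cannot hold,
# instead of failing mid-load with "Trying to convert Float8_e4m3fn ...".
# The set below is inferred from the errors reported in this thread.
MPS_UNSUPPORTED_DTYPES = {"float8_e4m3fn", "float8_e5m2"}

def check_mps_dtype(dtype_name: str) -> None:
    if dtype_name in MPS_UNSUPPORTED_DTYPES:
        raise ValueError(
            f"MPS backend does not support {dtype_name}; "
            "try an FP16/BF16 checkpoint instead."
        )

check_mps_dtype("float16")  # fine, returns silently
try:
    check_mps_dtype("float8_e4m3fn")
except ValueError as e:
    print(e)
```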
The GGUF version I tried did not work either; error: `Unsupported type byte size: UInt16`
The full FP16 Flux works, but it is horribly slow on my M3 Pro: about 6 min for 20 steps.
Also not sure why bfloat16 does not work, even with PyTorch 2.4.1 or nightly.
Update: or maybe it does work, but the code needs to be different.
Actually, if I do this:

```python
if pos.device.type == "mps":
    scale = torch.arange(0, dim, 2, dtype=torch.float16, device=pos.device) / dim
```

and use torch nightly (need to recheck 2.4.1), then bfloat16 seems to work, as I get this: `K-Model Created: {'storage_dtype': torch.bfloat16, 'computation_dtype': torch.bfloat16}`
Took 5:40 on Torch 2.6 nightly.
Works with PyTorch 2.4.1 too.
However, it does not really matter, as the FP8 and the Q4_1 GGUF checkpoints still give the same errors.
Mac M2 +1