thanks (it's 10x faster than JAX)!

Open Birch-san opened this issue 2 years ago • 14 comments

I've been trying to get dalle-playground running performantly on M1, but there's a lot of work remaining to make the JAX model work via IREE/Vulkan.

so, I tried out your pytorch model,

with a recent nightly of pytorch:

pip install --pre "torch>1.13.0.dev20220610" "torchvision>0.14.0.dev20220609" --extra-index-url https://download.pytorch.org/whl/nightly/cpu

…and it's 10x faster at dalle-mega than dalle-playground was on JAX/XLA!

using dalle-mega full:

wandb artifact get --root=./pretrained/dalle_bart_mega dalle-mini/dalle-mini/mega-1:latest

generating 1 image took 27 mins on dalle-playground (using 117% CPU), whereas this pytorch model runs in 2.7 mins (using 145% CPU)!
GPU looks less-than-half utilized. haven't checked whether pytorch is the process that's using the GPU.

these measurements are from M1 Max.

bonus
"crystal maiden and lina enjoying a pint together at a tavern"
[generated image]

Birch-san avatar Jun 29 '22 00:06 Birch-san

Awesome!

kuprel avatar Jun 29 '22 02:06 kuprel

generating 1 image took 27 mins on dalle-playground (using 117% CPU), whereas this pytorch model runs in 2.7 mins (using 145% CPU)! GPU looks less-than-half utilized. haven't checked whether pytorch is the process that's using the GPU.

I think the model runs on CPU by default. I tried to move all models and tensors to the mps device and fix some incompatibilities (a few ops are not yet supported by the MPS backend). Inference was faster and GPU utilization was close to 100%, but generation did not work properly. I'm still trying to identify what the problem could be.
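
For anyone following along, moving a model and its inputs to the MPS device looks roughly like the sketch below. This is a minimal illustration, not code from min-dalle: `pick_device` is a hypothetical helper and the `nn.Linear` layer is a stand-in for the real model. It falls back to the CPU when MPS is unavailable.

```python
import torch

def pick_device() -> torch.device:
    """Hypothetical helper: prefer MPS when this PyTorch build and machine support it."""
    mps = getattr(torch.backends, "mps", None)  # absent on older PyTorch builds
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = pick_device()

# Stand-in for the real model: both parameters and inputs must live on the same device.
model = torch.nn.Linear(4, 2).to(device)
x = torch.randn(3, 4, device=device)
y = model(x)

print(device, tuple(y.shape))
```

Tensors created without `device=...` land on the CPU, which is one common way a model ends up only partly on the GPU.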

pcuenca avatar Jul 02 '22 00:07 pcuenca

@pcuenca wait, you got it running on-GPU? and it was faster? that's massively different from the result I got.

here's how I made it run on MPS:
https://github.com/Birch-san/min-dalle/compare/Birch-san:min-dalle:main...Birch-san:min-dalle:mps
there's other stuff in that branch too like generating multiple images, re-using text encoding between images, measuring how long each step takes.

what I found was that it ran way slower. I left it overnight and it didn't finish generating even 1 image (got to the 145th token of 255, something like that).
and tbh the CPU usage (~117%) and GPU usage (less than half) looked identical to when it ran on-CPU.

did I do something wrong? I just slapped device_type on everything I could.
I'm using torch==1.13.0.dev20220628 (recent nightly).
ran with PYTORCH_ENABLE_MPS_FALLBACK=1, --mega --torch --text='kunkka playing basketball with Shrek' --copies=3. with dalle-mega proper, not the fp16 version.
only one operation had to use the fallback-to-CPU: aten::sort.values_stable.
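
As a rough sketch, the same fallback flag can be set from Python instead of the shell. The assumption here (labeled as such) is that the flag needs to be in the environment before `torch` is imported, so it is set at the very top of the script.

```python
import os

# Assumption: PyTorch reads this flag when the MPS backend is initialised,
# so it is set before any `import torch` in the process.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

fallback_enabled = os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK") == "1"
print("MPS CPU fallback requested:", fallback_enabled)
```

With the flag set, ops the MPS backend does not implement run on the CPU instead of raising an error, at the cost of extra device-to-host transfers.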

Birch-san avatar Jul 02 '22 00:07 Birch-san

generation did not work properly

it's worth knowing that the MPS backend does have some silent errors where it will produce incorrect output (or at least transfer the wrong result to CPU). here's the really wacky phenomenon that I found:
https://github.com/pytorch/pytorch/issues/79383

Birch-san avatar Jul 02 '22 00:07 Birch-san

@Birch-san These are my changes so far: https://github.com/kuprel/min-dalle/compare/main...pcuenca:min-dalle:mps-device

I tried to use workarounds for unsupported ops, except for multinomial. You need to use PYTORCH_ENABLE_MPS_FALLBACK=1 for the backend to automatically fall back to the CPU when it encounters that operation. I also tried to replace it with argmax, which should produce something reasonable, but it did not help with generation.
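
To make the multinomial-to-argmax swap concrete, here is a small sketch on a toy distribution. The logits are made up for illustration and this is not the sampling code from min-dalle; it only shows that `torch.argmax` gives a deterministic (greedy) pick where `torch.multinomial` draws a random sample.

```python
import torch

torch.manual_seed(0)

# Toy next-token scores for a vocabulary of 4 tokens (illustrative values only).
logits = torch.tensor([[0.1, 2.0, 0.3, 4.0]])
probs = torch.softmax(logits, dim=-1)

# Stochastic choice: draws one token index per row according to `probs`.
sampled = torch.multinomial(probs, num_samples=1)

# Greedy replacement: always picks the most likely token, no sampling op needed.
greedy = torch.argmax(probs, dim=-1, keepdim=True)

print(sampled.shape, greedy.item())
```

Greedy decoding removes the randomness, so repeated runs give the same tokens; it is a debugging aid rather than an equivalent of sampling.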

I may have introduced a problem somewhere, but if you disable the MPS device by returning self here, everything works right.

pcuenca avatar Jul 02 '22 00:07 pcuenca

it's worth knowing that the MPS backend does have some silent errors where it will produce incorrect output. here's the really wacky one I found: pytorch/pytorch#79383

That's very interesting. I'll try to debug generation tomorrow. Thanks!

pcuenca avatar Jul 02 '22 00:07 pcuenca

I was also looking into getting this model on the phone. Apple says that for transformers in pytorch, the dimensions aren't in optimal order for the neural engine: https://machinelearning.apple.com/research/neural-engine-transformers

They also convert all the linear layers to convs and use a different einsum pattern
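
A rough sketch of the linear-to-conv idea, under the assumption that a 1x1 `nn.Conv2d` over a (batch, channels, 1, sequence) layout is used in place of `nn.Linear`. The layer sizes are arbitrary and this is not Apple's or min-dalle's code; it only checks that the two layers compute the same thing once the weights are copied over.

```python
import torch
from torch import nn

torch.manual_seed(0)

linear = nn.Linear(8, 16)
conv = nn.Conv2d(8, 16, kernel_size=1)

# A 1x1 conv weight has shape (out, in, 1, 1); reuse the linear weights directly.
with torch.no_grad():
    conv.weight.copy_(linear.weight[:, :, None, None])
    conv.bias.copy_(linear.bias)

x = torch.randn(2, 5, 8)                   # (batch, sequence, channels)
y_linear = linear(x)                       # (batch, sequence, 16)

x_conv = x.transpose(1, 2).unsqueeze(2)    # (batch, channels, 1, sequence)
y_conv = conv(x_conv).squeeze(2).transpose(1, 2)  # back to (batch, sequence, 16)

print(torch.allclose(y_linear, y_conv, atol=1e-5))
```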

kuprel avatar Jul 02 '22 03:07 kuprel

I was also looking into getting this model on the phone. Apple says that for transformers in pytorch, the dimensions aren't in optimal order for the neural engine: https://machinelearning.apple.com/research/neural-engine-transformers

They also convert all the linear layers to convs and use a different einsum pattern

that's just the neural engine. PyTorch's MPS backend targets the GPU, and JAX's IREE/Vulkan backend does too. Dunno what TensorFlow targets. but I'll definitely take "targeting 48 GPU cores" as a step up from "targeting 10 CPU cores".

it sounds like the Neural Engine is not suitable for training anyway, only inferencing:
https://github.com/pytorch/pytorch/issues/47688#issuecomment-1066193714

Birch-san avatar Jul 02 '22 10:07 Birch-san

The neural engine is much faster than the GPU, so it makes sense to apply those optimizations. Not all operations are supported, however, and it's hard to know whether the system decided to run your model in the neural engine or the GPU.

I wasn't trying to do that yet, though. I just wanted to test inference in the MPS backend (GPU) of my M1 Mac to see how it compares with the CPU and with NVIDIA GPUs. If we did a conversion to Core ML, we would then be able to test neural engine inference speed vs PyTorch+MPS performance.

pcuenca avatar Jul 02 '22 10:07 pcuenca

@pcuenca

That's very interesting. I'll try to debug generation tomorrow. Thanks!

If it is indeed the problem of transferring from MPS to CPU, then we should try @qqaatw's idea for transferring as contiguous memory.

https://github.com/pytorch/pytorch/issues/79383#issuecomment-1172879881

Birch-san avatar Jul 02 '22 11:07 Birch-san

@pcuenca if I slap .contiguous() at the end of every torch.{reshape,view,unsqueeze,permute}() (i.e. functions which perform reshaping, and which may utilize a view to do so): we get an image that is merely bad rather than pitch-black:
[two generated images]
kunkka playing basketball with Shrek

https://github.com/Birch-san/min-dalle/commit/8b832319a3f490053c489ed80d2d7b27e436be56
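
For context, a small CPU-only sketch of what `.contiguous()` changes. It uses a toy tensor, not the min-dalle code: `permute` returns a strided view over the same storage, and `.contiguous()` copies it into a fresh, densely laid-out buffer with the same values.

```python
import torch

x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

# permute() returns a view: same storage as x, but with reordered strides.
view = x.permute(0, 2, 1)

# contiguous() materialises the view into its own row-major buffer.
dense = view.contiguous()

print(view.is_contiguous(), dense.is_contiguous())
print(view.data_ptr() == x.data_ptr(), dense.data_ptr() == x.data_ptr())
```

On a tensor that is already contiguous (often the case after `reshape` or `unsqueeze`), `.contiguous()` returns the tensor itself, so adding it everywhere only costs a copy where a strided view actually exists.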

Birch-san avatar Jul 02 '22 12:07 Birch-san

oh, there's one final reshape() that I missed. but adding .contiguous() to that makes things worse rather than better:

[generated image]
kunkka playing basketball with Shrek

https://github.com/Birch-san/min-dalle/commit/43e7e92dff37789a5c2a25e9dfd25fa00d277581

Birch-san avatar Jul 02 '22 12:07 Birch-san

I also tried using .contiguous() on any tensor that would be transferred to the MPS device:
https://github.com/Birch-san/min-dalle/commit/b1cf6c284a949d23f8c0cd6802bb207b876bf2af

still black.

Birch-san avatar Jul 02 '22 13:07 Birch-san

Even faster these days: you get a 4x4 grid instead of a 3x3 grid on Replicate, after the same duration.

However, this is based on Dall-E MEGA instead of Dall-E Mini, so results might differ. Not sure if better or worse.

woctezuma avatar Jul 04 '22 16:07 woctezuma