Img2Img in mlx-examples/stable_diffusion fails with reshape mismatch
I have only just started exploring all this wonderful MLX stuff, but when I try to run the image2image.py script on a simple .png I am getting an error that seems to arise during the unet.py reshaping of x using [B, H, W, C]:
ValueError: All the input array dimensions must match exactly except for the concatenation axis. However, the provided shapes are (8,18,18,1280), (8,17,17,1280), and the concatenation axis is -1
Somehow we go from [8,17,17,1280] to [8,18,18,1280] during def __call__() but I can't see an obvious reason or fix. Since I am only using one [528, 528, 3] image I don't think I can have caused this, but anything is possible! My bad if so.
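For what it's worth, the same ValueError can be reproduced in isolation by concatenating two arrays whose spatial dimensions differ by one (a minimal sketch, not the actual unet.py code):

import mlx.core as mx

a = mx.zeros((8, 18, 18, 1280))   # upsampled feature map
b = mx.zeros((8, 17, 17, 1280))   # saved residual from the down path
mx.concatenate([a, b], axis=-1)   # raises the same ValueError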
Incidentally, the im2im.png in the repo can't be read as original.py as in the sample prompt because of the addition of the other fireplace images.
I am able to replicate this. Seems to be a mismatch between the residual_hidden_state tensors that are saved during downsampling and the resulting shape when x is upsampled. They should be the same, but I suspect a padding issue... trying to find the exact line where things go wrong.

Edit: Yeah, looking at it a bit more, this is a common issue in UNet-like architectures where skip-connections are involved. The problem is that downsampling a shape of (17,17) gives (9,9), which when upsampled results in (18,18) - hence the error.

I'll try to draft up a PR for a solution to this problem.
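The arithmetic behind the mismatch can be seen in a couple of lines (a minimal sketch, assuming stride-2 downsampling that rounds odd sizes up and plain 2x upsampling, as is typical for these blocks):

h = 17                  # spatial size entering a downsampling block
down = (h + 1) // 2     # stride-2 with padding turns 17 into 9
up = down * 2           # 2x upsampling turns 9 into 18, not 17
print(h, down, up)      # 17 9 18 -> concatenation with the saved 17x17 residual fails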
Thank you for investigating. I did the padding exercise as you suggested to make the image [544, 544, 3] but I still got the same mismatch of dimensions as reported before.

In the end I decided that the problem arises because even 544 = 17 x 32 fails: I think we need an even multiple of 32, i.e. a multiple of 64, and when I force-padded the image to 576 using mx.pad(img, ((24,24), (24,24), (0,0))) everything worked. Such a restriction on the original image doesn't seem very satisfactory, but at least we are getting somewhere.
This is the code I then concocted, with some diagnostic printing, which needs to go right at the top of image2image.py just after sd = StableDiffusion():
sd = StableDiffusion()

# Read the image
img = mx.array(np.array(Image.open(args.image)))
print(f"Loaded image shape: {img.shape}")  # Output: [528, 528, 3] in this case
img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1

# Make image dimensions an even multiple of 64 = 2^6
# (assumes a square image: the padding is computed from and applied to both axes)
target_size = 64 * round((img.shape[0] + 32) / 64)  # Even multiple of 64
pad_amt = target_size - img.shape[0]  # Total amount to pad, split across the two sides
pad_params = ((pad_amt // 2, pad_amt - pad_amt // 2), (pad_amt // 2, pad_amt - pad_amt // 2), (0, 0))
img = mx.pad(img, pad_params)
print(f"\033[33mImage size = {img.shape}\033[0m")  # Output: [576, 576, 3]
Original (ironically produced by txt2img.py):
Product:
Great that you got it working! I wasn't sure how many UNetBlocks you were working with. The number of UNetBlocks should decide what your resolution needs to be divisible by, namely 2^nr_blocks, and hence what you need to pad to. My understanding is that this is inherent to the U-Net architecture.
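To make that concrete (illustrative numbers only, assuming the pipeline effectively halves the resolution six times, which matches the multiple-of-64 requirement found above):

nr_blocks = 6                    # assumed number of effective downsampling stages
multiple = 2 ** nr_blocks        # 64
for size in (528, 544, 576):
    print(size, size % multiple == 0)   # only 576 is divisible by 64, so only it works unpadded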
Thank you. I was just running the code out of the box and assuming that it would use any image supplied, which turned out not to be true. As I am sure we agree, it's fine to have to go into or understand the code during a development phase, but it isn't satisfactory if it's supposed to be for a general user who will just want to 'plug and play'. Your PR should address part of this.

Another problem, as I discovered trying to do [1024, 1024, 3], is that images larger than [576, 576, 3] can easily generate runtime errors because of lack of resources. That should probably be trapped somewhere too. Worth raising as an issue, or too obvious?
Well, that is completely dependent on the user's system. I feel a runtime error is a reasonable enough way to inform the user that they need to use smaller input images. Trying to pre-compute the resources available at run time is not something that is commonly done.