
Img2Img in mlx-examples/stable_diffusion fails with reshape mismatch

Open pudepiedj opened this issue 1 year ago • 5 comments

I have only just started exploring all this wonderful MLX stuff, but when I try to run the image2image.py script on a simple .png I get an error that seems to arise while unet.py reshapes x using [B, H, W, C]:

ValueError: All the input array dimensions must match exactly except for the concatenation axis. However, the provided shapes are (8,18,18,1280), (8,17,17,1280), and the concatenation axis is -1

Somehow we go from [8, 17, 17, 1280] to [8, 18, 18, 1280] during def __call__(), but I can't see an obvious reason or fix. Since I am only using one [528, 528, 3] image, I don't think I can have caused this, but anything is possible! My bad if so.

Incidentally, the im2im.png in the repo can't be read as the original image as in the sample prompt, because of the addition of the other fireplace images.

pudepiedj avatar Jan 09 '24 14:01 pudepiedj

I am able to replicate this. It seems to be a mismatch between the residual hidden states that are saved during downsampling and the resulting shape when x is upsampled. They should be the same, but I suspect a padding issue... trying to find the exact line where things go wrong.


Edit: Yeah, looking at it a bit more, this is a common issue in UNet-like architectures where skip-connections are involved. The problem is that downsampling a shape of (17, 17) gives (9, 9), which when upsampled results in (18, 18) - hence the error.
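The rounding behaviour behind the mismatch can be sketched in a few lines (assuming stride-2 downsampling that rounds up and nearest-neighbour 2x upsampling, which is typical of these UNets):

```python
import math

def down(n: int) -> int:
    # A stride-2 convolution with "same"-style padding halves the size, rounding up
    return math.ceil(n / 2)

def up(n: int) -> int:
    # Nearest-neighbour upsampling simply doubles the size
    return n * 2

h = 17
h_down = down(h)    # 9
h_back = up(h_down) # 18 - no longer matches the (17, 17) saved for the skip connection
print(h, h_down, h_back)
```

For even sizes the round trip is exact (16 -> 8 -> 16), which is why padding to the right multiple makes the error disappear.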

I'll try to draft up a PR for a solution to this problem.

LeonEricsson avatar Jan 09 '24 15:01 LeonEricsson

> I am able to replicate this. It seems to be a mismatch between the residual hidden states that are saved during downsampling and the resulting shape when x is upsampled. They should be the same, but I suspect a padding issue... trying to find the exact line where things go wrong.
>
> Edit: Yeah, looking at it a bit more, this is a common issue in UNet-like architectures where skip-connections are involved. The problem is that downsampling a shape of (17, 17) gives (9, 9), which when upsampled results in (18, 18) - hence the error.
>
> I'll try to draft up a PR for a solution to this problem.

Thank you for investigating. I did the padding exercise as you suggested to make the image [544, 544, 3], but I still got the same mismatch of dimensions as reported before.

In the end I decided that the problem arises because 544 = 17 x 32 is only an odd multiple of 32, and I think we need an even multiple of 32, i.e. a multiple of 64. When I force-padded the image to 576 using mx.pad(img, ((24, 24), (24, 24), (0, 0))), everything worked. Such a restriction on the original image doesn't seem very satisfactory, but at least we are getting somewhere.

This is the code I then concocted, with some diagnostic printing, which needs to go right at the top of image2image.py just after sd = StableDiffusion():

```python
# (these imports are already present in image2image.py)
import numpy as np
from PIL import Image
import mlx.core as mx

sd = StableDiffusion()

# Read the image
img = mx.array(np.array(Image.open(args.image)))
print(f"Loaded image shape: {img.shape}")  # Output: (528, 528, 3) in this case

# Drop any alpha channel and rescale to [-1, 1]
img = (img[:, :, :3].astype(mx.float32) / 255) * 2 - 1

# Make the image dimensions a multiple of 64 = 2^6 (assumes a square image)
target_size = 64 * round((img.shape[0] + 32) / 64)
pad_amt = target_size - img.shape[0]  # total amount to pad along each axis

pad_params = (
    (pad_amt // 2, pad_amt - pad_amt // 2),
    (pad_amt // 2, pad_amt - pad_amt // 2),
    (0, 0),
)
img = mx.pad(img, pad_params)

print(f"\033[33mImage size = {img.shape}\033[0m")  # Output: (576, 576, 3)
```
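The snippet above assumes a square image (it only inspects img.shape[0]). A version that pads height and width independently, under the same multiple-of-64 assumption, might look like this (pad_to_multiple is an illustrative helper, not part of the repo):

```python
import math

def pad_to_multiple(h: int, w: int, multiple: int = 64):
    """Compute symmetric (before, after) pad widths per axis so that both
    spatial dimensions become multiples of `multiple`."""
    def side(n: int):
        target = math.ceil(n / multiple) * multiple  # round up to the next multiple
        total = target - n
        return (total // 2, total - total // 2)     # split as evenly as possible
    return (side(h), side(w), (0, 0))               # no padding on the channel axis

print(pad_to_multiple(528, 528))  # ((24, 24), (24, 24), (0, 0))
```

The result can be passed straight to mx.pad as the pad-width argument.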

Original (ironically produced by txt2img.py) and the resulting product: [images attached]

pudepiedj avatar Jan 09 '24 19:01 pudepiedj

> Thank you for investigating. I did the padding exercise as you suggested to make the image [544, 544, 3], but I still got the same mismatch of dimensions as reported before.
>
> In the end I decided that the problem arises because 544 = 17 x 32 is only an odd multiple of 32, and I think we need an even multiple of 32, i.e. a multiple of 64. When I force-padded the image to 576 using mx.pad(img, ((24, 24), (24, 24), (0, 0))), everything worked. Such a restriction on the original image doesn't seem very satisfactory, but at least we are getting somewhere.

Great that you got it working. I wasn't sure how many UNet blocks you were working with; the number of downsampling blocks determines what your resolution needs to be divisible by (2^num_blocks), and hence what you need to pad to. My understanding is that this is inherent to the U-Net architecture.
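For the Stable Diffusion example this works out to the multiple-of-64 requirement found above, assuming an 8x VAE downscale plus three stride-2 UNet downsampling stages (a sketch under those assumptions, not read from the repo's config):

```python
def required_divisor(vae_factor: int = 8, num_unet_downsamples: int = 3) -> int:
    # Pixel resolution must be divisible by the VAE scale factor times
    # 2 raised to the number of UNet downsampling stages
    return vae_factor * (2 ** num_unet_downsamples)

def is_valid_size(n: int) -> bool:
    return n % required_divisor() == 0

print(required_divisor())   # 64
print(is_valid_size(528))   # False -> triggers the reshape mismatch
print(is_valid_size(576))   # True
```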

LeonEricsson avatar Jan 09 '24 19:01 LeonEricsson

> Great that you got it working. I wasn't sure how many UNet blocks you were working with; the number of downsampling blocks determines what your resolution needs to be divisible by (2^num_blocks), and hence what you need to pad to. My understanding is that this is inherent to the U-Net architecture.

Thank you. I was just running the code out of the box and assuming that it would accept any image supplied, which turned out not to be true. As I am sure we agree, it's fine to have to go into or understand the code during a development phase, but it isn't satisfactory if it's supposed to be for a general user who will just want to 'plug and play'. Your PR should address part of this.

Another problem, as I discovered trying [1024, 1024, 3], is that images larger than [576, 576, 3] can easily generate runtime errors because of lack of resources. That should probably be trapped somewhere too. Worth raising as an issue, or too obvious?
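A simple early guard could fail with a clear message before any denoising work starts. This is only a hypothetical sketch (the function name and the 768px cap are illustrative, not from the repo):

```python
def check_image_size(shape, max_side: int = 768):
    """Raise early rather than failing mid-generation on an oversized input.
    `max_side` is an assumed conservative cap, not a measured limit."""
    h, w = shape[0], shape[1]
    if max(h, w) > max_side:
        raise ValueError(
            f"Input image {h}x{w} exceeds {max_side}px per side; "
            "downscale it to avoid running out of memory."
        )

check_image_size((576, 576, 3))   # passes silently
```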

pudepiedj avatar Jan 10 '24 11:01 pudepiedj

> Another problem, as I discovered trying [1024, 1024, 3], is that images larger than [576, 576, 3] can easily generate runtime errors because of lack of resources. That should probably be trapped somewhere too. Worth raising as an issue, or too obvious?

Well, that is completely dependent on the user's system. I feel a runtime error is a reasonable enough way to inform users that they need to use smaller input images; trying to pre-compute the available resources in real time is not common practice.

LeonEricsson avatar Jan 10 '24 13:01 LeonEricsson