
Upsample recompute_scale_factor related error when running the notebook

Open · sammo opened this issue 2 years ago · 3 comments

Is anyone else facing this issue when running the notebook? If not, what version of PyTorch are you using?

This method appears to have changed in the most recent PyTorch release (1.11.0), so my guess is that the error has something to do with that.

AttributeError: 'Upsample' object has no attribute 'recompute_scale_factor'

sammo, Jun 19 '22

The error 'Upsample' object has no attribute 'recompute_scale_factor' comes from a change in the torch.nn.Upsample class between torch 1.10 and 1.11.

It appears that 'old' Upsample objects (saved with an earlier torch version, so they lack the recompute_scale_factor attribute) are stored inside the model loaded by this line of code: model = load_model("https://cdn.openai.com/dall-e/decoder.pkl", 'cuda'). I used the following code immediately after the load_model call to patch this:

# Patch for torch 1.11 and higher: replace each old Upsample object with a new one
# that exposes recompute_scale_factor
import torch
from dall_e import load_model

model = load_model("https://cdn.openai.com/dall-e/decoder.pkl", 'cuda')
for group in (model.blocks.group_1, model.blocks.group_2, model.blocks.group_3):
    old = group.upsample
    # Upsample has no weights, so rebuilding it from its settings loses nothing.
    group.upsample = torch.nn.Upsample(scale_factor=old.scale_factor, mode=old.mode)

With this patch, it runs fine with torch 1.12.1!

digitalShaman, Aug 27 '22
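A minimal sketch of why the patch is needed, using plain Python pickle rather than torch (the Upsampler class below is a hypothetical stand-in, not torch's class). Unpickling restores an object's saved __dict__ without calling __init__, so an attribute that a newer version of the class sets in __init__ is simply missing from objects saved by the older version. This assumes torch.load behaves like pickle in this respect for the saved modules, which is consistent with the error reported above.

```python
import pickle


class Upsampler:
    """Version 1 of a hypothetical class, as it existed when the model was saved."""

    def __init__(self, scale_factor, mode="nearest"):
        self.scale_factor = scale_factor
        self.mode = mode


# Save an instance created with version 1.
saved = pickle.dumps(Upsampler(2.0))


class Upsampler:  # deliberately redefined to simulate upgrading the library
    """Version 2: __init__ now sets an attribute that newer code relies on."""

    def __init__(self, scale_factor, mode="nearest", recompute_scale_factor=None):
        self.scale_factor = scale_factor
        self.mode = mode
        self.recompute_scale_factor = recompute_scale_factor


# pickle looks the class up by name, so it finds version 2, but it does not call
# __init__: it restores version 1's saved __dict__, which lacks the new attribute.
restored = pickle.loads(saved)
print(hasattr(restored, "recompute_scale_factor"))  # False

# Rebuilding from the saved settings (what the patch does) yields a complete object.
rebuilt = Upsampler(restored.scale_factor, restored.mode)
print(rebuilt.recompute_scale_factor)  # None
```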

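My solution is to load the pickled encoder/decoder from the online URLs and copy their state dicts into freshly constructed Encoder and Decoder modules. The fresh modules are built by the installed torch version, so their Upsample layers have recompute_scale_factor, and the state dict carries only the weights: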

import torch
from dall_e import Encoder, Decoder, load_model

# This can be changed to a GPU, e.g. 'cuda:0'.
dev = torch.device('cpu')
# For faster load times, download these files locally and use the local paths instead.
enc_old = load_model("https://cdn.openai.com/dall-e/encoder.pkl", dev)
dec_old = load_model("https://cdn.openai.com/dall-e/decoder.pkl", dev)

# Rebuild the modules with the installed torch, then copy only the learned tensors.
enc = Encoder()
enc.load_state_dict(enc_old.state_dict())
enc = enc.to(dev).eval()
dec = Decoder()
dec.load_state_dict(dec_old.state_dict())
dec = dec.to(dev).eval()

kamwoh, Nov 14 '22
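The same rebuild-and-copy idea can be checked on a toy model, assuming torch 1.11 or later is installed. build_toy_group is a hypothetical stand-in for one decoder group, not DALL-E's actual Decoder, and deleting the attribute only simulates a module that was unpickled from an older torch. Because Upsample has no parameters or buffers, its settings never appear in the state dict, so loading the old state dict into freshly built modules loses nothing.

```python
import torch
from torch import nn


def build_toy_group() -> nn.Sequential:
    """Hypothetical toy stand-in for one decoder group: a conv followed by an Upsample."""
    return nn.Sequential(
        nn.Conv2d(3, 3, kernel_size=1),
        nn.Upsample(scale_factor=2, mode="nearest"),
    )


# Simulate a module unpickled from an older torch: drop the attribute newer torch expects.
old = build_toy_group()
del old[1].recompute_scale_factor

# Upsample has no parameters or buffers, so the state dict holds only the conv tensors.
print(sorted(old.state_dict().keys()))  # ['0.bias', '0.weight']

# Rebuild with the installed torch and copy the learned tensors across.
fresh = build_toy_group()
fresh.load_state_dict(old.state_dict())
fresh.eval()

x = torch.randn(1, 3, 4, 4)
with torch.no_grad():
    y = fresh(x)
print(tuple(y.shape))  # (1, 3, 8, 8)
```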

I use an exception handler around "self.recompute_scale_factor" in torch's upsampling.py (torch/nn/modules/upsampling.py). If that attribute doesn't exist at runtime, my modified code assigns False to it directly.

Interestingly, this exception isn't raised every time, and I am still not sure about the mechanism behind it.

lzh107u, Apr 24 '23
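The exception-handler fix above edits torch's own source. A sketch of a similar effect without touching the installed package, assuming it is acceptable to patch the loaded objects once right after loading: walk the objects and set the missing attribute only where it is absent. backfill_missing_attr is a hypothetical helper, not part of torch or dall_e.

```python
from typing import Any, Iterable


def backfill_missing_attr(objects: Iterable[Any], cls: type, name: str, default: Any) -> int:
    """Set `name` to `default` on every instance of `cls` that lacks it.

    Objects that already have the attribute are left untouched.
    Returns the number of objects that were patched.
    """
    patched = 0
    for obj in objects:
        if isinstance(obj, cls) and not hasattr(obj, name):
            setattr(obj, name, default)
            patched += 1
    return patched
```

With torch, this would be called as backfill_missing_attr(model.modules(), torch.nn.Upsample, "recompute_scale_factor", None) right after load_model. None mirrors the default that the torch.nn.Upsample constructor uses in 1.11 and later; the patch described above assigns False instead.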