DALL-E
Upsample recompute_scale_factor related error when running the notebook
Is anyone facing this issue when running the notebook? If not, what version of pytorch are you using?
I see this method was changed in the most recent release of PyTorch (1.11.0), so my guess is that it has something to do with it.
AttributeError: 'Upsample' object has no attribute 'recompute_scale_factor'
The error 'Upsample' object has no attribute 'recompute_scale_factor'
is related to a change in the torch Upsample class from 1.10 to 1.11.
It appears that 'old' Upsample objects are saved within the pickled model and are restored by this line of code:
model = load_model("https://cdn.openai.com/dall-e/decoder.pkl", 'cuda')
I used the following code immediately after the load_model call to patch this:
# Patch for torch 1.11 and higher: replace each old Upsample object with a new one
# that exposes recompute_scale_factor
for name in ('group_1', 'group_2', 'group_3'):
    group = getattr(model.blocks, name)
    old = group.upsample
    group.upsample = torch.nn.Upsample(scale_factor=old.scale_factor, mode=old.mode)
and it's running fine with torch 1.12.1!
My solution is to create fresh Encoder/Decoder instances and load the state dicts of the online encoder/decoder into them:
import torch
from dall_e import Encoder, Decoder, load_model
# This can be changed to a GPU, e.g. 'cuda:0'.
dev = torch.device('cpu')
# For faster load times, download these files locally and use the local paths instead.
enc_old = load_model("https://cdn.openai.com/dall-e/encoder.pkl", dev)
dec_old = load_model("https://cdn.openai.com/dall-e/decoder.pkl", dev)
enc = Encoder()
enc.load_state_dict(enc_old.state_dict())
enc.eval()
dec = Decoder()
dec.load_state_dict(dec_old.state_dict())
dec.eval()
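To illustrate why the state-dict approach sidesteps the error, here is a minimal sketch on a toy module. TinyDecoder is a hypothetical stand-in, not part of dall_e, and removing the attribute by hand is only my way of simulating an Upsample object unpickled from an older torch version. The freshly constructed module gets its Upsample from the installed torch, and load_state_dict copies only the tensors, never the stale Python objects.

```python
import torch
from torch import nn


class TinyDecoder(nn.Module):
    """Hypothetical stand-in for a decoder that contains an Upsample layer."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=1)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        return self.upsample(self.conv(x))


# Simulate a model restored from an old pickle: drop the attribute that
# newer torch versions expect on Upsample (no-op if it is not there).
old = TinyDecoder()
old.upsample.__dict__.pop('recompute_scale_factor', None)

# Build a fresh instance with the installed torch and copy the weights only.
new = TinyDecoder()
new.load_state_dict(old.state_dict())
new.eval()

with torch.no_grad():
    y = new(torch.randn(1, 3, 4, 4))
```

Upsample has no parameters of its own, so the state dict here contains only the convolution weights, and the fresh Upsample is left untouched.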
I use an exception handler around the access to "self.recompute_scale_factor" in upsampling.py. If that attribute doesn't exist during execution, my modified code assigns False to it directly.
Interestingly, this exception is not raised every time, and I am still not sure about the mechanism behind it.
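A similar effect can probably be had without editing upsampling.py inside the installed torch package, by setting the missing attribute on the loaded model instead. The sketch below is only an illustration: patch_upsample is a hypothetical helper name, and the default of None is my assumption, chosen because it is the default of the torch.nn.Upsample constructor (pass value=False to mirror the fix described above).

```python
import torch
from torch import nn


def patch_upsample(model: nn.Module, value=None) -> int:
    """Set recompute_scale_factor on every nn.Upsample that lacks it.

    Returns the number of modules that were patched.
    """
    patched = 0
    for module in model.modules():
        if isinstance(module, nn.Upsample) and not hasattr(module, 'recompute_scale_factor'):
            module.recompute_scale_factor = value
            patched += 1
    return patched


# Demo on a toy model; the pop() simulates an Upsample restored from an old pickle.
toy = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'))
toy[0].__dict__.pop('recompute_scale_factor', None)

count = patch_upsample(toy)
out = toy(torch.zeros(1, 3, 4, 4))
```

With the notebook's model this would be called once, right after load_model, e.g. patch_upsample(model). Because it walks model.modules(), it does not depend on the group_1/group_2/group_3 names.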