RePaint

Guided-diffusion pretrained model can't sample in RePaint

Open xiongsua opened this issue 3 years ago • 9 comments

I get an error when sampling with RePaint. I trained a model with guided-diffusion and then used the EMA .pt checkpoint to sample. Does the guided-diffusion code need to be changed, and if so, what do I need to change in it?

xiongsua avatar Oct 16 '22 06:10 xiongsua

Same problem. Do you have any solution?

drakiez13 avatar Nov 02 '22 16:11 drakiez13

It seems that the model architecture has changed.

Try to use the model code from your training by copying it here.

Does that work?
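For example, here is a minimal sketch to see which parameter names actually differ between your checkpoint and the model that test.py builds (the checkpoint path is a placeholder, and it assumes the checkpoint is a plain state_dict as guided-diffusion saves it):

```python
import torch as th


def diff_keys(model, ckpt_path):
    """Print which parameter names differ between a model and a checkpoint."""
    ckpt = th.load(ckpt_path, map_location="cpu")
    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(ckpt.keys())
    print("expected by the model, missing in the checkpoint:", sorted(model_keys - ckpt_keys))
    print("in the checkpoint, unexpected by the model:", sorted(ckpt_keys - model_keys))


# Usage (hypothetical): build `model` exactly the way test.py does, then call
# diff_keys(model, "path/to/ema_0.9999_XXXXXX.pt")
```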

andreas128 avatar Nov 05 '22 01:11 andreas128

Hi, did you solve the problem?

zhangbaijin avatar Mar 11 '23 09:03 zhangbaijin

> I get an error when sampling with RePaint. I trained a model with guided-diffusion and then used the EMA .pt checkpoint to sample. Does the guided-diffusion code need to be changed, and if so, what do I need to change in it?

zhangbaijin avatar Mar 11 '23 09:03 zhangbaijin

@xiongsua Hi, did you solve the problem?

xyz-xdx avatar Mar 30 '23 04:03 xyz-xdx

I have the same problem: the values in the checkpoint do not match the network.

RuntimeError: Error(s) in loading state_dict for UNetModel:
Missing key(s) in state_dict:
"input_blocks.4.0.in_layers.0.weight", "input_blocks.4.0.in_layers.0.bias", "input_blocks.4.0.in_layers.2.weight", "input_blocks.4.0.in_layers.2.bias", "input_blocks.4.0.emb_layers.1.weight", "input_blocks.4.0.emb_layers.1.bias", "input_blocks.4.0.out_layers.0.weight", "input_blocks.4.0.out_layers.0.bias", "input_blocks.4.0.out_layers.3.weight", "input_blocks.4.0.out_layers.3.bias",
"input_blocks.5.1.norm.weight", "input_blocks.5.1.norm.bias", "input_blocks.5.1.qkv.weight", "input_blocks.5.1.qkv.bias", "input_blocks.5.1.proj_out.weight", "input_blocks.5.1.proj_out.bias",
"input_blocks.6.1.norm.weight", "input_blocks.6.1.norm.bias", "input_blocks.6.1.qkv.weight", "input_blocks.6.1.qkv.bias", "input_blocks.6.1.proj_out.weight", "input_blocks.6.1.proj_out.bias",
"input_blocks.7.1.norm.weight", "input_blocks.7.1.norm.bias", "input_blocks.7.1.qkv.weight", "input_blocks.7.1.qkv.bias", "input_blocks.7.1.proj_out.weight", "input_blocks.7.1.proj_out.bias",
"input_blocks.8.0.in_layers.0.weight", "input_blocks.8.0.in_layers.0.bias", "input_blocks.8.0.in_layers.2.weight", "input_blocks.8.0.in_layers.2.bias", "input_blocks.8.0.emb_layers.1.weight", "input_blocks.8.0.emb_layers.1.bias", "input_blocks.8.0.out_layers.0.weight", "input_blocks.8.0.out_layers.0.bias", "input_blocks.8.0.out_layers.3.weight", "input_blocks.8.0.out_layers.3.bias",
"input_blocks.12.0.in_layers.0.weight", "input_blocks.12.0.in_layers.0.bias", "input_blocks.12.0.in_layers.2.weight", "input_blocks.12.0.in_layers.2.bias", "input_blocks.12.0.emb_layers.1.weight", "input_blocks.12.0.emb_layers.1.bias", "input_blocks.12.0.out_layers.0.weight", "input_blocks.12.0.out_layers.0.bias", "input_blocks.12.0.out_layers.3.weight", "input_blocks.12.0.out_layers.3.bias",
"output_blocks.3.2.in_layers.0.weight", "output_blocks.3.2.in_layers.0.bias", "output_blocks.3.2.in_layers.2.weight", "output_blocks.3.2.in_layers.2.bias", "output_blocks.3.2.emb_layers.1.weight", "output_blocks.3.2.emb_layers.1.bias", "output_blocks.3.2.out_layers.0.weight", "output_blocks.3.2.out_layers.0.bias", "output_blocks.3.2.out_layers.3.weight", "output_blocks.3.2.out_layers.3.bias",
"output_blocks.7.2.in_layers.0.weight", "output_blocks.7.2.in_layers.0.bias", "output_blocks.7.2.in_layers.2.weight", "output_blocks.7.2.in_layers.2.bias", "output_blocks.7.2.emb_layers.1.weight", "output_blocks.7.2.emb_layers.1.bias", "output_blocks.7.2.out_layers.0.weight", "output_blocks.7.2.out_layers.0.bias", "output_blocks.7.2.out_layers.3.weight", "output_blocks.7.2.out_layers.3.bias",
"output_blocks.8.1.norm.weight", "output_blocks.8.1.norm.bias", "output_blocks.8.1.qkv.weight", "output_blocks.8.1.qkv.bias", "output_blocks.8.1.proj_out.weight", "output_blocks.8.1.proj_out.bias",
"output_blocks.9.1.norm.weight", "output_blocks.9.1.norm.bias", "output_blocks.9.1.qkv.weight", "output_blocks.9.1.qkv.bias", "output_blocks.9.1.proj_out.weight", "output_blocks.9.1.proj_out.bias",
"output_blocks.10.1.norm.weight", "output_blocks.10.1.norm.bias", "output_blocks.10.1.qkv.weight", "output_blocks.10.1.qkv.bias", "output_blocks.10.1.proj_out.weight", "output_blocks.10.1.proj_out.bias",
"output_blocks.11.1.norm.weight", "output_blocks.11.1.norm.bias", "output_blocks.11.1.qkv.weight", "output_blocks.11.1.qkv.bias", "output_blocks.11.1.proj_out.weight", "output_blocks.11.1.proj_out.bias",
"output_blocks.11.2.in_layers.0.weight", "output_blocks.11.2.in_layers.0.bias", "output_blocks.11.2.in_layers.2.weight", "output_blocks.11.2.in_layers.2.bias", "output_blocks.11.2.emb_layers.1.weight", "output_blocks.11.2.emb_layers.1.bias", "output_blocks.11.2.out_layers.0.weight", "output_blocks.11.2.out_layers.0.bias", "output_blocks.11.2.out_layers.3.weight", "output_blocks.11.2.out_layers.3.bias".
Unexpected key(s) in state_dict:
"input_blocks.4.0.op.weight", "input_blocks.4.0.op.bias", "input_blocks.8.0.op.weight", "input_blocks.8.0.op.bias", "input_blocks.12.0.op.weight", "input_blocks.12.0.op.bias", "output_blocks.3.2.conv.weight", "output_blocks.3.2.conv.bias", "output_blocks.7.2.conv.weight", "output_blocks.7.2.conv.bias", "output_blocks.11.1.conv.weight", "output_blocks.11.1.conv.bias".
size mismatch for out.2.weight: copying a param with shape torch.Size([3, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([6, 128, 3, 3]).
size mismatch for out.2.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([6]).

fangtun avatar Jun 30 '23 08:06 fangtun


@fangtun Hi, I have the same issue. Have you solved it?

handsomecong001 avatar Sep 27 '23 01:09 handsomecong001

Hi, I also have the same issue.

Traceback (most recent call last):
  File "test.py", line 180, in <module>
    main(conf_arg)
  File "test.py", line 69, in main
    model.load_state_dict(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 2001, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for UNetModel:
Missing key(s) in state_dict:
"input_blocks.3.0.in_layers.0.weight", "input_blocks.3.0.in_layers.0.bias", "input_blocks.3.0.in_layers.2.weight", "input_blocks.3.0.in_layers.2.bias", "input_blocks.3.0.emb_layers.1.weight", "input_blocks.3.0.emb_layers.1.bias", "input_blocks.3.0.out_layers.0.weight", "input_blocks.3.0.out_layers.0.bias", "input_blocks.3.0.out_layers.3.weight", "input_blocks.3.0.out_layers.3.bias",
"input_blocks.6.0.in_layers.0.weight", "input_blocks.6.0.in_layers.0.bias", "input_blocks.6.0.in_layers.2.weight", "input_blocks.6.0.in_layers.2.bias", "input_blocks.6.0.emb_layers.1.weight", "input_blocks.6.0.emb_layers.1.bias", "input_blocks.6.0.out_layers.0.weight", "input_blocks.6.0.out_layers.0.bias", "input_blocks.6.0.out_layers.3.weight", "input_blocks.6.0.out_layers.3.bias",
"input_blocks.9.0.in_layers.0.weight", "input_blocks.9.0.in_layers.0.bias", "input_blocks.9.0.in_layers.2.weight", "input_blocks.9.0.in_layers.2.bias", "input_blocks.9.0.emb_layers.1.weight", "input_blocks.9.0.emb_layers.1.bias", "input_blocks.9.0.out_layers.0.weight", "input_blocks.9.0.out_layers.0.bias", "input_blocks.9.0.out_layers.3.weight", "input_blocks.9.0.out_layers.3.bias",
"input_blocks.12.0.in_layers.0.weight", "input_blocks.12.0.in_layers.0.bias", "input_blocks.12.0.in_layers.2.weight", "input_blocks.12.0.in_layers.2.bias", "input_blocks.12.0.emb_layers.1.weight", "input_blocks.12.0.emb_layers.1.bias", "input_blocks.12.0.out_layers.0.weight", "input_blocks.12.0.out_layers.0.bias", "input_blocks.12.0.out_layers.3.weight", "input_blocks.12.0.out_layers.3.bias",
"input_blocks.15.0.in_layers.0.weight", "input_blocks.15.0.in_layers.0.bias", "input_blocks.15.0.in_layers.2.weight", "input_blocks.15.0.in_layers.2.bias", "input_blocks.15.0.emb_layers.1.weight", "input_blocks.15.0.emb_layers.1.bias", "input_blocks.15.0.out_layers.0.weight", "input_blocks.15.0.out_layers.0.bias", "input_blocks.15.0.out_layers.3.weight", "input_blocks.15.0.out_layers.3.bias",
"output_blocks.2.2.in_layers.0.weight", "output_blocks.2.2.in_layers.0.bias", "output_blocks.2.2.in_layers.2.weight", "output_blocks.2.2.in_layers.2.bias", "output_blocks.2.2.emb_layers.1.weight", "output_blocks.2.2.emb_layers.1.bias", "output_blocks.2.2.out_layers.0.weight", "output_blocks.2.2.out_layers.0.bias", "output_blocks.2.2.out_layers.3.weight", "output_blocks.2.2.out_layers.3.bias",
"output_blocks.5.2.in_layers.0.weight", "output_blocks.5.2.in_layers.0.bias", "output_blocks.5.2.in_layers.2.weight", "output_blocks.5.2.in_layers.2.bias", "output_blocks.5.2.emb_layers.1.weight", "output_blocks.5.2.emb_layers.1.bias", "output_blocks.5.2.out_layers.0.weight", "output_blocks.5.2.out_layers.0.bias", "output_blocks.5.2.out_layers.3.weight", "output_blocks.5.2.out_layers.3.bias",
"output_blocks.8.2.in_layers.0.weight", "output_blocks.8.2.in_layers.0.bias", "output_blocks.8.2.in_layers.2.weight", "output_blocks.8.2.in_layers.2.bias", "output_blocks.8.2.emb_layers.1.weight", "output_blocks.8.2.emb_layers.1.bias", "output_blocks.8.2.out_layers.0.weight", "output_blocks.8.2.out_layers.0.bias", "output_blocks.8.2.out_layers.3.weight", "output_blocks.8.2.out_layers.3.bias",
"output_blocks.11.1.in_layers.0.weight", "output_blocks.11.1.in_layers.0.bias", "output_blocks.11.1.in_layers.2.weight", "output_blocks.11.1.in_layers.2.bias", "output_blocks.11.1.emb_layers.1.weight", "output_blocks.11.1.emb_layers.1.bias", "output_blocks.11.1.out_layers.0.weight", "output_blocks.11.1.out_layers.0.bias", "output_blocks.11.1.out_layers.3.weight", "output_blocks.11.1.out_layers.3.bias",
"output_blocks.14.1.in_layers.0.weight", "output_blocks.14.1.in_layers.0.bias", "output_blocks.14.1.in_layers.2.weight", "output_blocks.14.1.in_layers.2.bias", "output_blocks.14.1.emb_layers.1.weight", "output_blocks.14.1.emb_layers.1.bias", "output_blocks.14.1.out_layers.0.weight", "output_blocks.14.1.out_layers.0.bias", "output_blocks.14.1.out_layers.3.weight", "output_blocks.14.1.out_layers.3.bias".
Unexpected key(s) in state_dict:
"input_blocks.3.0.op.weight", "input_blocks.3.0.op.bias", "input_blocks.6.0.op.weight", "input_blocks.6.0.op.bias", "input_blocks.9.0.op.weight", "input_blocks.9.0.op.bias", "input_blocks.12.0.op.weight", "input_blocks.12.0.op.bias", "input_blocks.15.0.op.weight", "input_blocks.15.0.op.bias", "output_blocks.2.2.conv.weight", "output_blocks.2.2.conv.bias", "output_blocks.5.2.conv.weight", "output_blocks.5.2.conv.bias", "output_blocks.8.2.conv.weight", "output_blocks.8.2.conv.bias", "output_blocks.11.1.conv.weight", "output_blocks.11.1.conv.bias", "output_blocks.14.1.conv.weight", "output_blocks.14.1.conv.bias".

@handsomecong001 Have you tried it?

seungwooham avatar Mar 05 '24 01:03 seungwooham

It might not work perfectly, depending on your arguments. However, in my case, updating `UNetModel` and adding `SiLU` to the imports worked.
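For the import, one option (an assumption on my side: PyTorch >= 1.7, where `torch.nn` provides `SiLU`; if your `guided_diffusion/nn.py` already defines `SiLU`, import it from there instead) is:

```python
# Make the bare SiLU() calls in the class below resolve.
from torch.nn import SiLU
```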

```python
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.

:param in_channels: channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
    attention will take place. May be a set, list, or tuple.
    For example, if this contains 4, then at 4x downsampling, attention
    will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
    downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param num_classes: if specified (as an int), then this model will be
    class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
"""

def __init__(
    self,
    image_size,
    in_channels,
    model_channels,
    out_channels,
    num_res_blocks,
    attention_resolutions,
    dropout=0,
    channel_mult=(1, 2, 4, 8),
    conv_resample=True,
    dims=2,
    num_classes=None,
    use_checkpoint=False,
    use_fp16=False,
    num_heads=1,
    num_head_channels=-1,
    num_heads_upsample=-1,
    use_scale_shift_norm=False,
    resblock_updown=False,
    use_new_attention_order=False,
    conf=None
):
    super().__init__()

    if num_heads_upsample == -1:
        num_heads_upsample = num_heads

    self.image_size = image_size
    self.in_channels = in_channels
    self.model_channels = model_channels
    self.out_channels = out_channels
    self.num_res_blocks = num_res_blocks
    self.attention_resolutions = attention_resolutions
    self.dropout = dropout
    self.channel_mult = channel_mult
    self.conv_resample = conv_resample
    self.num_classes = num_classes
    self.use_checkpoint = use_checkpoint
    self.dtype = th.float16 if use_fp16 else th.float32
    self.num_heads = num_heads
    self.num_head_channels = num_head_channels
    self.num_heads_upsample = num_heads_upsample
    self.conf = conf

    time_embed_dim = model_channels * 4
    self.time_embed = nn.Sequential(
        linear(model_channels, time_embed_dim),
        SiLU(),
        linear(time_embed_dim, time_embed_dim),
    )

    if self.num_classes is not None:
        self.label_emb = nn.Embedding(num_classes, time_embed_dim)

    self.input_blocks = nn.ModuleList(
        [
            TimestepEmbedSequential(
                conv_nd(dims, in_channels, model_channels, 3, padding=1)
            )
        ]
    )
    input_block_chans = [model_channels]
    ch = model_channels
    ds = 1
    for level, mult in enumerate(channel_mult):
        for _ in range(num_res_blocks):
            layers = [
                ResBlock(
                    ch,
                    time_embed_dim,
                    dropout,
                    out_channels=mult * model_channels,
                    dims=dims,
                    use_checkpoint=use_checkpoint,
                    use_scale_shift_norm=use_scale_shift_norm,
                )
            ]
            ch = mult * model_channels
            if ds in attention_resolutions:
                layers.append(
                    AttentionBlock(
                        ch,
                        use_checkpoint=use_checkpoint,
                        num_heads=num_heads,
                        num_head_channels=num_head_channels,
                        use_new_attention_order=use_new_attention_order,
                    )
                )
            self.input_blocks.append(TimestepEmbedSequential(*layers))
            input_block_chans.append(ch)
        if level != len(channel_mult) - 1:
            self.input_blocks.append(
                TimestepEmbedSequential(Downsample(ch, conv_resample, dims=dims))
            )
            input_block_chans.append(ch)
            ds *= 2

    self.middle_block = TimestepEmbedSequential(
        ResBlock(
            ch,
            time_embed_dim,
            dropout,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_scale_shift_norm=use_scale_shift_norm,
        ),
        AttentionBlock(ch,
                       use_checkpoint=use_checkpoint,
                       num_heads=num_heads,
                       num_head_channels=num_head_channels,
                       use_new_attention_order=use_new_attention_order,
        ),
        ResBlock(
            ch,
            time_embed_dim,
            dropout,
            dims=dims,
            use_checkpoint=use_checkpoint,
            use_scale_shift_norm=use_scale_shift_norm,
        ),
    )

    self.output_blocks = nn.ModuleList([])
    for level, mult in list(enumerate(channel_mult))[::-1]:
        for i in range(num_res_blocks + 1):
            layers = [
                ResBlock(
                    ch + input_block_chans.pop(),
                    time_embed_dim,
                    dropout,
                    out_channels=model_channels * mult,
                    dims=dims,
                    use_checkpoint=use_checkpoint,
                    use_scale_shift_norm=use_scale_shift_norm,
                )
            ]
            ch = model_channels * mult
            if ds in attention_resolutions:
                layers.append(
                    AttentionBlock(
                        ch,
                        use_checkpoint=use_checkpoint,
                        num_heads=num_heads_upsample,
                        num_head_channels=num_head_channels,
                        use_new_attention_order=use_new_attention_order,
                    )
                )
            if level and i == num_res_blocks:
                layers.append(Upsample(ch, conv_resample, dims=dims))
                ds //= 2
            self.output_blocks.append(TimestepEmbedSequential(*layers))

    self.out = nn.Sequential(
        normalization(ch),
        SiLU(),
        zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
    )

def convert_to_fp16(self):
    """
    Convert the torso of the model to float16.
    """
    self.input_blocks.apply(convert_module_to_f16)
    self.middle_block.apply(convert_module_to_f16)
    self.output_blocks.apply(convert_module_to_f16)

def convert_to_fp32(self):
    """
    Convert the torso of the model to float32.
    """
    self.input_blocks.apply(convert_module_to_f32)
    self.middle_block.apply(convert_module_to_f32)
    self.output_blocks.apply(convert_module_to_f32)

@property
def inner_dtype(self):
    """
    Get the dtype used by the torso of the model.
    """
    return next(self.input_blocks.parameters()).dtype

def forward(self, x, timesteps, y=None, gt=None, **kwargs):
    """
    Apply the model to an input batch.

    :param x: an [N x C x ...] Tensor of inputs.
    :param timesteps: a 1-D batch of timesteps.
    :param y: an [N] Tensor of labels, if class-conditional.
    :return: an [N x C x ...] Tensor of outputs.
    """
    assert (y is not None) == (
        self.num_classes is not None
    ), "must specify y if and only if the model is class-conditional"

    hs = []
    emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

    if self.num_classes is not None:
        assert y.shape == (x.shape[0],)
        emb = emb + self.label_emb(y)

    h = x.type(self.inner_dtype)
    for module in self.input_blocks:
        h = module(h, emb)
        hs.append(h)
    h = self.middle_block(h, emb)
    for module in self.output_blocks:
        cat_in = th.cat([h, hs.pop()], dim=1)
        h = module(cat_in, emb)
    h = h.type(x.dtype)
    return self.out(h)

def get_feature_vectors(self, x, timesteps, y=None):
    """
    Apply the model and return all of the intermediate tensors.

    :param x: an [N x C x ...] Tensor of inputs.
    :param timesteps: a 1-D batch of timesteps.
    :param y: an [N] Tensor of labels, if class-conditional.
    :return: a dict with the following keys:
             - 'down': a list of hidden state tensors from downsampling.
             - 'middle': the tensor of the output of the lowest-resolution
                         block in the model.
             - 'up': a list of hidden state tensors from upsampling.
    """
    hs = []
    emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
    if self.num_classes is not None:
        assert y.shape == (x.shape[0],)
        emb = emb + self.label_emb(y)
    result = dict(down=[], up=[])
    h = x.type(self.inner_dtype)
    for module in self.input_blocks:
        h = module(h, emb)
        hs.append(h)
        result["down"].append(h.type(x.dtype))
    h = self.middle_block(h, emb)
    result["middle"] = h.type(x.dtype)
    for module in self.output_blocks:
        cat_in = th.cat([h, hs.pop()], dim=1)
        h = module(cat_in, emb)
        result["up"].append(h.type(x.dtype))
    return result

class SuperResModel(UNetModel):
    """
    A UNetModel that performs super-resolution.

Expects an extra kwarg `low_res` to condition on a low-resolution image.
"""

def __init__(self, image_size, in_channels, *args, **kwargs):
    super().__init__(image_size, in_channels * 2, *args, **kwargs)

def forward(self, x, timesteps, low_res=None, **kwargs):
    _, _, new_height, new_width = x.shape
    upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
    x = th.cat([x, upsampled], dim=1)
    return super().forward(x, timesteps, **kwargs)

def get_feature_vectors(self, x, timesteps, low_res=None, **kwargs):
    _, _, new_height, new_width = x.shape
    upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
    x = th.cat([x, upsampled], dim=1)
    return super().get_feature_vectors(x, timesteps, **kwargs)

```
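A quick way to sanity-check the edited file before running the full RePaint pipeline is to construct the model with small, hypothetical settings and run one forward pass (not the RePaint config, just a shape check):

```python
import torch as th
from guided_diffusion.unet import UNetModel  # the module you just edited

model = UNetModel(
    image_size=64,
    in_channels=3,
    model_channels=128,
    out_channels=3,
    num_res_blocks=2,
    attention_resolutions=(8,),
    num_heads=4,
)
x = th.randn(1, 3, 64, 64)   # dummy image batch
t = th.tensor([10])          # dummy timestep
print(model(x, t).shape)     # expected: torch.Size([1, 3, 64, 64])
```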

seungwooham avatar Mar 05 '24 02:03 seungwooham