
Example for how to use the grow option?

[Open] Quasimondo opened this issue 2 years ago • 14 comments

I am trying to use the progressive growing option, but I get an error when I use it the way I think it is supposed to be used:

I have a trained 32x32 checkpoint which I am now trying to grow to a 64x64 one, so I am using the following arguments:

python3 train.py --config configs/config_64x64.json --name chkpt_64_1 --batch-size 100 --grow chkpt_32_2.pth --grow-config configs/config_32x32.json

config_32x32.json is the default one from the repository; config_64x64.json uses the additional layers and changed values mentioned in #9:

#from 32x32:

 "model": {
        "type": "image_v1",
        "input_channels": 3,
        "input_size": [32, 32],
        "patch_size": 1,
        "mapping_out": 256,
        "depths": [2, 4, 4],
        "channels": [128, 256, 512],
        "self_attn_depths": [false, true, true],
        "dropout_rate": 0.05,
        "augment_prob": 0.12,
        "sigma_data": 0.5,
        "sigma_min": 1e-2,
        "sigma_max": 80,
        "sigma_sample_density": {
            "type": "lognormal",
            "mean": -1.2,
            "std": 1.2
        }
    },
    
#from 64x64:

"model": {
        "type": "image_v1",
        "input_channels": 3,
        "input_size": [64, 64],
        "patch_size": 1,
        "mapping_out": 256,
        "depths": [2, 2, 4, 4],
        "channels": [128, 256, 256, 512],
        "self_attn_depths": [false, false, true, true],
        "dropout_rate": 0.05,
        "augment_prob": 0.12,
        "sigma_data": 0.5,
        "sigma_min": 1e-2,
        "sigma_max": 80,
        "sigma_sample_density": {
            "type": "lognormal",
            "mean": -1.2,
            "std": 1.2
        }
    },

But when I try to run train.py, I get a whole lot of "missing key" and "size mismatch" errors in inner_model.load_state_dict(old_inner_model.state_dict()):

Missing key(s) in state_dict: "inner_model.u_net.d_blocks.1.2.main.0.mapper.weight", "inner_model.u_net.d_blocks.1.2.main.0.mapper.bias", "inner_model.u_net.d_blocks.1.2.main.2.weight", "inner_model.u_net.d_blocks.1.2.main.2.bias", "inn....
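To see which parameters actually disagree before calling load_state_dict, it can help to compare the two models' parameter shapes directly. Below is a minimal sketch in plain Python; diff_state_dict_shapes is a hypothetical helper (not part of k-diffusion) that takes {name: shape} maps, which you could build from a PyTorch model with {k: tuple(v.shape) for k, v in model.state_dict().items()}. The example names and shapes are made up for illustration. Note that passing strict=False to PyTorch's load_state_dict tolerates missing and unexpected keys but still raises on size mismatches.

```python
from typing import Dict, List, Mapping, Tuple

Shape = Tuple[int, ...]


def diff_state_dict_shapes(old: Mapping[str, Shape],
                           new: Mapping[str, Shape]) -> Dict[str, List]:
    """Compare {parameter name: shape} maps from an old and a new model.

    missing    -> names the new model expects but the old checkpoint lacks
    unexpected -> names in the old checkpoint that the new model does not have
    mismatched -> (name, old_shape, new_shape) for names present in both
                  whose shapes differ
    """
    missing = sorted(k for k in new if k not in old)
    unexpected = sorted(k for k in old if k not in new)
    mismatched = sorted(
        (k, tuple(old[k]), tuple(new[k]))
        for k in set(old) & set(new)
        if tuple(old[k]) != tuple(new[k])
    )
    return {"missing": missing, "unexpected": unexpected,
            "mismatched": mismatched}


# Illustrative shapes only; conv weights are (out_ch, in_ch, kh, kw).
old_shapes = {"proj_in.weight": (128, 3, 1, 1),
              "stage0.conv.weight": (128, 128, 3, 3)}
new_shapes = {"proj_in.weight": (256, 3, 1, 1),
              "stage0.conv.weight": (256, 256, 3, 3),
              "stage1.conv.weight": (512, 256, 3, 3)}
report = diff_state_dict_shapes(old_shapes, new_shapes)
```

Here report lists stage1.conv.weight as missing and both proj_in.weight and stage0.conv.weight as shape mismatches, which is the same kind of information the RuntimeError above reports, but in a form you can inspect programmatically.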

So I am wondering whether I am doing something wrong here or if this is just one of those "work in progress" issues.

I suspect that I might instead have to do something involving patch_size and skip_stages, since those are used in the wrapper, but I have no idea what they do.

Quasimondo, Jul 29 '22 14:07

Right now you have to train a model that already contains the U-Net stages you are going to train later; this ensures that when you grow the model, all of the shapes will be compatible with the smaller model. You do this by putting "skip_stages": 1 (or some other value) in what you have for the 64x64 config and setting "input_size": [32, 32]. This makes the model skip the outer stage of the U-Net during training and inference, so you can train at 32x32 first and then change skip_stages to 0 to train at 64x64. The same works for any number of stages; for now, you just have to decide in advance how far you are going to grow the model.
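The recipe above can be sketched as a helper that copies the full-size config and sets the two fields mentioned: input_size and skip_stages. This is a minimal sketch, not part of k-diffusion: make_reduced_config is a hypothetical name, and it assumes square inputs, patch_size 1, and that each U-Net stage halves the resolution, so skipping k outer stages trains at full_size / 2**k.

```python
import copy


def make_reduced_config(full_config: dict, target_size: int) -> dict:
    """Return a copy of full_config set up to train at target_size.

    Hypothetical helper. Assumes square inputs, patch_size 1, and that each
    U-Net stage halves the resolution, so skipping k outer stages means
    training at full_size / 2**k.
    """
    model = full_config["model"]
    full_size = model["input_size"][0]
    n_stages = len(model["depths"])

    # Count how many outer stages must be skipped to reach target_size.
    skip, size = 0, full_size
    while size > target_size and size % 2 == 0:
        size //= 2
        skip += 1
    if size != target_size:
        raise ValueError(f"{target_size} is not {full_size} / 2**k")
    if skip >= n_stages:
        raise ValueError("cannot skip every U-Net stage")

    reduced = copy.deepcopy(full_config)  # leave the full config untouched
    reduced["model"]["input_size"] = [target_size, target_size]
    reduced["model"]["skip_stages"] = skip
    return reduced
```

For a four-stage 64x64 config, make_reduced_config(config, 32) returns a copy with "input_size": [32, 32] and "skip_stages": 1; you could write it out with json.dump and pass that file via --config. All other sections of the config are carried over unchanged by the deep copy.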

Note that this doesn't apply to growing by increasing the patch size: that doesn't add U-Net stages, so you can do it without deciding in advance what your final patch size is going to be.
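One way to see why the two kinds of growth differ is to list the grid size each active U-Net stage works at. The sketch below uses a hypothetical helper (active_stage_grids is not a k-diffusion function) and assumes the image is first patchified to image_size // patch_size, that the skip_stages outermost stages are bypassed, and that each remaining stage halves the resolution.

```python
from typing import List


def active_stage_grids(image_size: int, patch_size: int, n_stages: int,
                       skip_stages: int = 0) -> List[int]:
    """Grid side length at each U-Net stage that actually runs.

    Hypothetical helper. Assumes the image is patchified to
    image_size // patch_size, the skip_stages outermost stages are bypassed,
    and each remaining stage runs at half the resolution of the previous one.
    """
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    if not 0 <= skip_stages < n_stages:
        raise ValueError("skip_stages must leave at least one stage active")
    grid = image_size // patch_size
    grids = []
    for _ in range(n_stages - skip_stages):
        grids.append(grid)
        grid //= 2
    return grids
```

Under these assumptions, active_stage_grids(32, 1, 3) and active_stage_grids(64, 2, 3) both give [32, 16, 8]: doubling the patch size keeps the same stages at the same grid sizes. Going from active_stage_grids(32, 1, 4, skip_stages=1) to active_stage_grids(64, 1, 4) instead adds a 64x64 stage in front, which is why those stages have to exist in the model from the start.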

crowsonkb, Jul 29 '22 15:07

If someone has a good idea on how to reasonably progressively grow models without creating all of the U-Net stages first, in a way that ensures the shapes are compatible, I'd like to hear it :)

crowsonkb, Jul 29 '22 15:07

Ah thanks for the swift explanation! Of course now it makes total sense.

My naive approach for growing without creating the stages first would be to try to do a "smart copy" of weights from the smaller model into modules of the bigger one that have the same size and relative level and just leave the missing ones with random init values, but I fear that's probably not the way to do it.

Quasimondo, Jul 29 '22 15:07

Quoting crowsonkb: If someone has a good idea on how to reasonably progressively grow models without creating all of the U-Net stages first, in a way that ensures the shapes are compatible, I'd like to hear it :)

This has probably already been tried, but what are your thoughts on using inpainting to increase the resolution? An example would be something like this: (1) generate a 32x32 image, (2) crop it into 4 equal squares of size 16x16, (3) for each of those 4 squares, double the resolution to 32x32, mask the new unknown pixels, and use inpainting to generate them, (4) stitch the 4 images together into a 64x64 one. Inpainting is a bit slow, but it would not require additional U-Nets.
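The crop, spread-and-mask, and stitch steps above can be sketched in plain Python on grayscale images stored as lists of rows. This is an illustration under stated assumptions, not an existing k-diffusion feature: the inpaint argument stands in for whatever diffusion inpainting sampler you would use, and nearest_fill is only a placeholder (with it, the result is just a nearest-neighbor 2x upscale). Because each quadrant is inpainted independently, nothing here enforces consistency across the seams between tiles.

```python
from typing import Callable, List, Tuple

Image = List[List[float]]  # grayscale, list of rows
Mask = List[List[bool]]    # True = unknown pixel to be generated


def crop_quadrants(img: Image) -> List[Image]:
    """Step 2: split an N x N image into four N/2 x N/2 tiles (TL, TR, BL, BR)."""
    h = len(img) // 2
    return [[row[c0:c0 + h] for row in img[r0:r0 + h]]
            for r0 in (0, h) for c0 in (0, h)]


def spread_with_mask(tile: Image) -> Tuple[Image, Mask]:
    """Step 3a: double the resolution, keeping known pixels at even coordinates."""
    m = 2 * len(tile)
    out = [[0.0] * m for _ in range(m)]
    mask = [[True] * m for _ in range(m)]
    for y, row in enumerate(tile):
        for x, value in enumerate(row):
            out[2 * y][2 * x] = value
            mask[2 * y][2 * x] = False
    return out, mask


def nearest_fill(img: Image, mask: Mask) -> Image:
    """Placeholder for a diffusion inpainting sampler: copy the known pixel
    at the even coordinates above/left of each masked pixel."""
    return [[img[y - y % 2][x - x % 2] if mask[y][x] else img[y][x]
             for x in range(len(img[0]))] for y in range(len(img))]


def stitch(tl: Image, tr: Image, bl: Image, br: Image) -> Image:
    """Step 4: assemble four tiles into one image of twice the side length."""
    return ([a + b for a, b in zip(tl, tr)] +
            [a + b for a, b in zip(bl, br)])


def upscale_by_inpainting(img: Image,
                          inpaint: Callable[[Image, Mask], Image]) -> Image:
    """Steps 2-4: crop, spread and inpaint each tile, then stitch (2x upscale)."""
    tiles = [inpaint(*spread_with_mask(q)) for q in crop_quadrants(img)]
    return stitch(*tiles)
```

For a 32x32 input this produces four 32x32 tiles and a 64x64 result. With the placeholder, upscale_by_inpainting([[1, 2], [3, 4]], nearest_fill) returns [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]].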

flavioschneider, Jul 29 '22 16:07

My gut feeling is that the 32x32 model does not really know about the finer details that should be present in a 64x64 model, so I would expect the result to be rather on the blurry, pixelated side, except maybe if you train it on a wide range of differently scaled details (so that it has, for example, learned what a close-up of an eye looks like and not just two black pixels in a face).

Quasimondo, Jul 29 '22 16:07

Pictures in the dataset are usually not taken at the same zoom level, I think the model should know how something looks if it's "closer" in order to inpaint it correctly. But yes, you are probably right that if we abuse this process too much (e.g. by repeating it recursively) something weird will happen. For example, zooming on finer details of a face that are not present in any image of the dataset.

flavioschneider, Jul 29 '22 16:07

Oh yes, that's a possibility of course. I have only just started diving into training my own diffusion models, but one observation I made with my toy models is that the old rule still applies: well-aligned datasets of similar things converge better than those that are all over the place in zoom level or composition. That is why I currently try to keep my data within a certain theme or scale level (and why I try to use the grow method).

Quasimondo, Jul 29 '22 17:07

Just FYI: I was curious to see what happens if I map the most likely corresponding weights of the 32 model to the 64 model, and as it turns out there are just 6 weight tensors in the smaller model that have no match in the bigger one; all the others can be copied.

And whilst the first sample is just noise, training seems to recover quite quickly and, at least from what I can see so far, does not show any weird behavior. It may be too early to tell, though.

In case you are interested in looking deeper into this option, I can share my (super-hacky) code snippets somewhere.

Quasimondo, Jul 29 '22 18:07

I'm interested! That should usually work. The thing I am specifically worried about is progressive growing where you change the base channel count, say if you have [256, 256, 512] for your 32x32 and then you want [128, 256, 256, 512] for your 64x64. Then the first residual block of the 32x32 model has an input channel count of 256 and an output channel count of 256, but in the bigger model the corresponding residual block has an input channel count of 128 (because there is now a 64x64, 128-channel stage before it). So you have to drop that residual block and replace it with a new, randomly initialized one with the correct shapes. This is why I went with creating the entire model beforehand and changing the number of stages you skip: it will create the first residual block of the 32x32 stage with 128 input channels, because it knows you're going to add a 128-channel 64x64 stage later.
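The situation described above, where the old tensors that still fit are kept and the rest stay freshly initialized, can be sketched as a shape-checked copy driven by an explicit old-to-new key mapping, such as the CSV shared later in this thread. copy_compatible is a hypothetical helper; it only relies on values having a .shape, so it works on PyTorch state dicts as well as on the stand-in objects used in the example.

```python
from typing import Dict, List, Mapping, Tuple


def copy_compatible(old_state: Mapping, new_state: Dict,
                    mapping: Mapping[str, str]) -> Tuple[List[str], List[str]]:
    """Copy old tensors into new_state wherever the mapped shapes agree.

    mapping is {old_key: new_key}. A pair is skipped if either key is absent
    or the shapes differ (e.g. a residual block whose input channel count
    changed from 256 to 128); skipped entries keep new_state's values, i.e.
    the new model's random initialization.
    """
    copied, skipped = [], []
    for old_key, new_key in mapping.items():
        if old_key not in old_state or new_key not in new_state:
            skipped.append(new_key)
        elif tuple(old_state[old_key].shape) != tuple(new_state[new_key].shape):
            skipped.append(new_key)
        else:
            new_state[new_key] = old_state[old_key]
            copied.append(new_key)
    return copied, skipped


class FakeTensor:
    """Stand-in for torch.Tensor in this example; only .shape is used."""

    def __init__(self, *shape: int) -> None:
        self.shape = shape


# Conv weights are (out_ch, in_ch, kh, kw). The first conv of the old 32x32
# block sees 256 input channels, but once a 128-channel 64x64 stage is added
# in front it sees 128, so that pair cannot be copied.
old = {"res0.conv1.weight": FakeTensor(256, 256, 3, 3),
       "res1.conv1.weight": FakeTensor(256, 256, 3, 3)}
new = {"s1.res0.conv1.weight": FakeTensor(256, 128, 3, 3),
       "s1.res1.conv1.weight": FakeTensor(256, 256, 3, 3)}
copied, skipped = copy_compatible(
    old, new, {"res0.conv1.weight": "s1.res0.conv1.weight",
               "res1.conv1.weight": "s1.res1.conv1.weight"})
```

With a real model you would start from new_dict = inner_model.state_dict(), run the copy, and then call inner_model.load_state_dict(new_dict); since every key is already present, the skipped entries simply keep their fresh initialization.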

crowsonkb, Jul 29 '22 20:07

Yeah, I guess the question might be whether the model learns bad superficial habits that way, relying on the skip connections while neglecting the "deeper" layers that likely hold more global knowledge. That is hard to tell with my small training set and at that size.

Here is a CSV with the 32-to-64 model mapping: mapping_32_to_64.csv

The patched part of the code in train.py looks like this:

if args.grow:
    if not args.grow_config:
        raise ValueError('--grow requires --grow-config')
    ckpt = torch.load(args.grow, map_location='cpu')
    old_config = K.config.load_config(open(args.grow_config))
    old_inner_model = K.config.make_model(old_config)
    old_inner_model.load_state_dict(ckpt['model_ema'])

    if old_config['model']['skip_stages'] != model_config['skip_stages']:
        old_inner_model.set_skip_stages(model_config['skip_stages'])
    if old_config['model']['patch_size'] != model_config['patch_size']:
        old_inner_model.set_patch_size(model_config['patch_size'])

    # Each non-empty line of the CSV is "old_key,new_key". splitlines() and
    # strip() also cope with Windows line endings and stray whitespace.
    with open("mapping_32_to_64.csv", "r") as f:
        mapping_lut = f.read().splitlines()

    old_dict = old_inner_model.state_dict()
    new_dict = inner_model.state_dict()

    # Copy each mapped tensor from the small model into the big one; keys
    # without a mapping keep the new model's random initialization.
    for line in mapping_lut:
        oldnew = line.strip().split(",")
        if len(oldnew) == 2:
            new_dict[oldnew[1]] = old_dict[oldnew[0]]

    # new_dict started as inner_model.state_dict(), so no keys are missing.
    inner_model.load_state_dict(new_dict)

    del old_dict, new_dict, mapping_lut
    del ckpt, old_inner_model

The two conf files and the args are the ones I used at the top of this thread.

Quasimondo, Jul 29 '22 20:07

Quoting crowsonkb: You do this by putting "skip_stages": 1 or some other value in what you have for the 64x64 config and setting "input_size": [32, 32].

Somehow I'm not understanding what this means, maybe because I'm still not very familiar with the U-Net architecture. Are you saying you set input_size to [32, 32] for both configs and then just add skip_stages to the larger model? Is it implicitly doubling the size, then, or is the intended [64, 64] size specified somewhere else? Any chance you could provide an example config?

kjhenner, Jul 29 '22 20:07

From what I understand, you first have to decide the maximum size you want to train at and create that config. For 128 that would be something like:

#config_128.json
"model": {
      "type": "image_v1",
      "input_channels": 3,
      "input_size": [128, 128],
      "skip_stages":0,
      "mapping_out": 256,
      "depths": [2, 2, 2, 4, 4],
      "channels": [128, 256, 256, 512, 512],
      "self_attn_depths": [false, false, false, true, true],
      "dropout_rate": 0.05,
      "augment_prob": 0.12,
      "sigma_data": 0.5,
      "sigma_min": 1e-2,
      "sigma_max": 160,
      "sigma_sample_density": {
          "type": "lognormal",
          "mean": -1.2,
          "std": 1.2
      }
  },.....

Now you make a copy of that conf and for the first stage (assuming you start with 32x32) you change the values in the copy to:

#config_32x32_skip.json
"model": {
      "type": "image_v1",
      "input_channels": 3,
      "input_size": [32, 32],
      "skip_stages":2,
      "mapping_out": 256,
      "depths": [2, 2, 2, 4, 4],
      "channels": [128, 256, 256, 512, 512],
      "self_attn_depths": [false, false, false, true, true],
      "dropout_rate": 0.05,
      "augment_prob": 0.12,
      "sigma_data": 0.5,
      "sigma_min": 1e-2,
      "sigma_max": 160,
      "sigma_sample_density": {
          "type": "lognormal",
          "mean": -1.2,
          "std": 1.2
      }
  },....

(I don't know whether the sigma_max value has to be reduced to 80 here.)

In the first stage you do not use the grow argument yet. When the 32x32 model has finished training you create another conf for the 64x64 step:

#config_64x64_skip.json
 "model": {
        "type": "image_v1",
        "input_channels": 3,
        "input_size": [64, 64],
        "skip_stages":1,
        "mapping_out": 256,
        "depths": [2, 2, 2, 4, 4],
        "channels": [128, 256, 256, 512, 512],
        "self_attn_depths": [false, false, false, true, true],
        "dropout_rate": 0.05,
        "augment_prob": 0.12,
        "sigma_data": 0.5,
        "sigma_min": 1e-2,
        "sigma_max": 160,
        "sigma_sample_density": {
            "type": "lognormal",
            "mean": -1.2,
            "std": 1.2
        }
    },....

This time you have to use the --grow argument:

python3 train.py --config configs/config_64x64_skip.json --name chkpt_64 --grow chkpt_32.pth --grow-config configs/config_32x32_skip.json

And once that has finished you can use the 128 conf:

python3 train.py --config configs/config_128.json --name chkpt_128 --grow chkpt_64.pth --grow-config configs/config_64x64_skip.json
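The sequence of runs above follows a fixed pattern: the first step trains without --grow, and every later step grows from the previous step's checkpoint and config. A small sketch that generates those command lines is below. grow_commands is hypothetical, uses only the flags shown in this thread, and by default assumes each step grows from a checkpoint file named after the previous run (chkpt_32.pth and so on, as in the commands above); pass a different checkpoint_of if your saved checkpoints are named differently.

```python
import shlex
from typing import Callable, List, Sequence, Tuple


def grow_commands(steps: Sequence[Tuple[str, str]],
                  checkpoint_of: Callable[[str], str] = lambda name: f"{name}.pth",
                  ) -> List[str]:
    """Build one train.py command line per progressive-growing step.

    steps: (config_path, run_name) pairs, ordered from smallest to largest.
    checkpoint_of: maps a run name to the checkpoint file to grow from
    (hypothetical default: "<run_name>.pth").
    """
    commands = []
    for i, (config, name) in enumerate(steps):
        cmd = ["python3", "train.py", "--config", config, "--name", name]
        if i > 0:
            # Grow from the previous step's checkpoint, using its config.
            prev_config, prev_name = steps[i - 1]
            cmd += ["--grow", checkpoint_of(prev_name),
                    "--grow-config", prev_config]
        commands.append(shlex.join(cmd))  # quotes paths only if needed
    return commands


if __name__ == "__main__":
    for line in grow_commands([
        ("configs/config_32x32_skip.json", "chkpt_32"),
        ("configs/config_64x64_skip.json", "chkpt_64"),
        ("configs/config_128.json", "chkpt_128"),
    ]):
        print(line)
```

For these three steps, the second and third generated lines are exactly the two commands shown above, and the first is the plain 32x32 run without --grow.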

Quasimondo, Jul 29 '22 21:07

Trying to grow to a 256 model, I am wondering what the considerations are when adding channels to grow the size. Looking at the progression from 32 to 128, I see:
32 -> [128, 256, 512]
64 -> [128, 256, 256, 512]
128 -> [128, 256, 256, 512, 512]

So for 256 I see 3 possible options, but I am not sure whether this is just trial and error or whether there are certain pros and cons:
256 -> [128, 128, 256, 256, 512, 512]
256 -> [128, 256, 256, 256, 512, 512]
256 -> [128, 256, 256, 512, 512, 512]

My guess is that using fewer channels will require less memory, allow a larger batch size, and make training faster, at the cost of expressiveness?

[Note: I just saw that the UNet2DModel from huggingface/diffusers uses the 128, 128, 256, 256, 512, 512 variant, so I guess I'll try that one first.]

[Note 2: it looks like 24 GB of GPU memory is not sufficient to train a 256 model, even with a batch size of 1 :-(]
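One rough way to compare the three options is to add up channels × height × width of one feature map per stage, which approximates how activation memory scales with the channel list. This is a back-of-the-envelope proxy, not a measurement: it assumes a 256x256 input with patch_size 1 and that stage i runs at 256 / 2**i, and it ignores depths, attention, batch size, and everything stored for the backward pass.

```python
from typing import Sequence


def activation_proxy(channels: Sequence[int], image_size: int) -> int:
    """Sum of channels * H * W over stages, one feature map per stage.

    Rough proxy only: assumes stage i runs at image_size // 2**i and ignores
    depths, attention, batch size and saved activations for backprop.
    """
    total = 0
    size = image_size
    for c in channels:
        total += c * size * size
        size //= 2  # next stage runs at half the resolution
    return total


OPTIONS = {
    "option 1": [128, 128, 256, 256, 512, 512],
    "option 2": [128, 256, 256, 256, 512, 512],
    "option 3": [128, 256, 256, 512, 512, 512],
}

if __name__ == "__main__":
    for label, chans in OPTIONS.items():
        print(f"{label}: {chans} -> {activation_proxy(chans, 256):,}")
```

Under this proxy the outermost 256x256 stage alone contributes 8,388,608 (128 × 256 × 256) in every option, and the totals are 11,960,320, 14,057,472 and 14,319,616. Option 1 is smaller mainly because its 128x128 stage has 128 instead of 256 channels, which suggests that the channel counts of the outer, high-resolution stages matter most for memory. The proxy says nothing about expressiveness.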

Quasimondo, Jul 30 '22 07:07

A little tip for those who are not interested in the evaluation data: the argument --evaluate-every -1 skips that step and saves some time. In addition, the whole feature-evaluation preparation can then be skipped, which, most importantly, frees some GPU memory and in turn allows a larger batch size:

in train.py:

# If evaluation is disabled (--evaluate-every <= 0), skip building the
# InceptionV3 feature extractor and computing features for the reals;
# this saves time and frees the GPU memory the extractor would use.
if args.evaluate_every > 0:
    extractor = K.evaluation.InceptionV3FeatureExtractor(device=device)
    train_iter = iter(train_dl)
    if accelerator.is_main_process:
        print('Computing features for reals...')
    reals_features = K.evaluation.compute_features(accelerator, lambda x: next(train_iter)[image_key][1], extractor, args.evaluate_n, args.batch_size)
    if accelerator.is_main_process:
        metrics_log_filepath = Path(f'{args.name}_metrics.csv')
        if metrics_log_filepath.exists():
            metrics_log_file = open(metrics_log_filepath, 'a')
        else:
            metrics_log_file = open(metrics_log_filepath, 'w')
            print('step', 'fid', 'kid', sep=',', file=metrics_log_file, flush=True)
    del train_iter
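The snippet above only skips the preparation; the training loop also has to avoid calling the evaluation when it is disabled. The thread does not show how train.py checks the interval, so the following is only a sketch of such a guard. should_evaluate is a hypothetical helper, and the argparse setup (including its default) is illustrative rather than copied from train.py. The explicit > 0 test matters because in Python step % -1 == 0 for every integer step, so a bare modulo check would evaluate on every step with --evaluate-every -1 (and step % 0 would raise ZeroDivisionError).

```python
import argparse
from typing import Optional, Sequence


def should_evaluate(step: int, evaluate_every: int) -> bool:
    """True when an evaluation is due; a non-positive interval disables it."""
    return evaluate_every > 0 and step % evaluate_every == 0


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    # Illustrative parser with only the option relevant here.
    parser = argparse.ArgumentParser()
    parser.add_argument("--evaluate-every", type=int, default=10000,  # hypothetical default
                        help="evaluate every N steps; <= 0 disables evaluation")
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args(["--evaluate-every", "-1"])
    due = [step for step in range(1000) if should_evaluate(step, args.evaluate_every)]
    print(args.evaluate_every, due)  # -1 []
```

argparse accepts "-1" as the option's value here because the parser defines no options that look like negative numbers.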

Quasimondo, Jul 30 '22 10:07