
Adding more "hints" to the training process

orydatadudes opened this issue 1 year ago • 14 comments

Hi, I was working on the human posture task (getting the pose from an OpenPose image + a prompt, and then generating the character in the right pose, using control_sd15_openpose.pth).

However, I wanted to add one more hint to force the ControlNet to generate a specific human. In the original code the hint is a posture image like this:

[image: OpenPose skeleton, v2-c5e272899550ac318ed4732336fd7c82_720w]

I would like to add one more image of the specific person:

[image: reference photo of the person, MEN_Denim_id_00000080_0_01_7_additional]

The target should be an image of that person in the new pose.

So what I did is:

  1. In the dataset file: read that extra image too, and concatenate it with the posture image along the channel dimension, so the source variable now has 6 channels instead of 3:

     # concatenate the pose image and the extra reference image along the channel axis
     source = np.concatenate([source, source_image], axis=2)

     return dict(jpg=target, txt=prompt, hint=source)

  2. Change the YAML config file to support 6 channels (I am not sure I really understood the meaning of these values):

model:
  target: cldm.cldm.ControlLDM
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    control_key: "hint"
    image_size: 64
    channels: 7 # was 4, I changed it to 7
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    only_mid_control: False

control_stage_config:
  target: cldm.cldm.ControlNet
  params:
    image_size: 32 # unused
    in_channels: 7 # was 4, I changed it to 7
    hint_channels: 6 # was 3, I changed it to 6
    model_channels: 320
    attention_resolutions: [ 4, 2, 1 ]
    num_res_blocks: 2
    channel_mult: [ 1, 2, 4, 4 ]
    num_heads: 8
    use_spatial_transformer: True
    transformer_depth: 1
    context_dim: 768
    use_checkpoint: True
    legacy: False

unet_config:
  target: cldm.cldm.ControlledUnetModel
  params:
    image_size: 32 # unused
    in_channels: 7 # was 4, I changed it to 7
    out_channels: 7 # was 4, I changed it to 7
    model_channels: 320
    attention_resolutions: [ 4, 2, 1 ]
    num_res_blocks: 2
    channel_mult: [ 1, 2, 4, 4 ]
    num_heads: 8
    use_spatial_transformer: True
    transformer_depth: 1
    context_dim: 768
    use_checkpoint: True
    legacy: False

first_stage_config:
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    embed_dim: 7 # was 4, I changed it to 7
    monitor: val/rec_loss
    ddconfig:
      double_z: true
      z_channels: 7 # was 4, I changed it to 7
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult:
      - 1
      - 2
      - 4
      - 4
      num_res_blocks: 2
      attn_resolutions: []
      dropout: 0.0
    lossconfig:
      target: torch.nn.Identity

cond_stage_config:
  target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

The problem is that when I train the model from scratch (running tutorial_train.py with resume_path = None), the model predictions, the reconstructions, and the samples located under image_log/train are just noise.

Does anyone have any idea how to solve this? Thanks.

orydatadudes avatar Mar 14 '23 10:03 orydatadudes

I don't think you need to start training from scratch. If you fine-tune from a trained model, it should converge faster.
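For example, a minimal sketch of that kind of warm start (assuming the repo's cldm.model helpers; the paths are illustrative, and the yaml is your modified 6-channel config):

    from cldm.model import create_model, load_state_dict

    # Build the modified model, then copy over every pretrained tensor whose
    # shape still matches; the widened 6-channel layers keep their fresh init.
    model = create_model('./models/cldm_v15.yaml').cpu()
    pretrained = load_state_dict('./models/control_sd15_openpose.pth', location='cpu')

    own = model.state_dict()
    filtered = {k: v for k, v in pretrained.items()
                if k in own and v.shape == own[k].shape}
    model.load_state_dict(filtered, strict=False)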

liuzhihui2046 avatar Mar 16 '23 06:03 liuzhihui2046

To my knowledge, you should probably alter hint_channels under control_stage_config instead of channels
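In other words (a sketch of the relevant part, not the exact config from this thread): keep every latent-space size at its pretrained value and only widen the hint input:

    control_stage_config:
      target: cldm.cldm.ControlNet
      params:
        in_channels: 4    # latent channels, keep at the pretrained value
        hint_channels: 6  # 2 stacked RGB hints instead of 1
        # (all other params unchanged)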

fkcptlst avatar Apr 09 '23 04:04 fkcptlst

Hi, do you know why use_ema is set to False?

SnowdenLee avatar Apr 24 '23 08:04 SnowdenLee

Hello!

I'm making a similar attempt to yours.

Do you have any further results?

Yuhyeong avatar Sep 19 '23 10:09 Yuhyeong

> I'm making a similar attempt to yours. Do you have any further results?

Hello!

Have you solved the problem? I wonder if I could learn from your work.

Thank you

MuyuenHoshino avatar Jan 03 '24 09:01 MuyuenHoshino

> Have you solved the problem? I wonder if I could learn from your work.

I finished my work months ago. It works, but it is not significantly effective.

In the config, I only changed hint_channels to 6.

Then I merged two 3-channel images into one 6-channel image, saved it as a TIFF, and created a customized dataset object for training. My dataset code is below.

import json
import os

import cv2
import numpy as np
import tifffile
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self):
        self.data = []
        with open('./training/pose+face/prompt.json', 'rt') as f:
            for line in f:
                self.data.append(json.loads(line))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        source_filename = item['source']
        target_filename = item['target']
        prompt = item['prompt']

        # The 6-channel source is stored channel-first in the TIFF.
        source = tifffile.imread(os.path.join('./training/pose+face/source', source_filename))
        target = cv2.imread(os.path.join('./training/pose+face/target', target_filename))

        # Do not forget that OpenCV reads images in BGR order.
        target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)

        # Move channels last, then normalize source images to [0, 1].
        source = np.transpose(source, (1, 2, 0))
        source = source.astype(np.float32) / 255.0

        # Normalize target images to [-1, 1].
        target = (target.astype(np.float32) / 127.5) - 1.0

        return dict(jpg=target, txt=prompt, hint=source)
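For reference, a minimal sketch of how such a 6-channel TIFF could be produced (the file names are illustrative, not from the thread):

    import cv2
    import numpy as np
    import tifffile

    # Hypothetical preprocessing: stack a pose map and a face/reference image
    # into one 6-channel array and store it channel-first, matching the
    # np.transpose(source, (1, 2, 0)) in __getitem__ above.
    pose = cv2.cvtColor(cv2.imread('pose.png'), cv2.COLOR_BGR2RGB)  # (H, W, 3)
    face = cv2.cvtColor(cv2.imread('face.png'), cv2.COLOR_BGR2RGB)  # (H, W, 3)

    merged = np.concatenate([pose, face], axis=2)  # (H, W, 6)
    merged = np.transpose(merged, (2, 0, 1))       # (6, H, W)
    tifffile.imwrite('source.tiff', merged)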

Yuhyeong avatar Jan 03 '24 10:01 Yuhyeong

> I finished my work months ago. It works, but it is not significantly effective. In the config, I only changed hint_channels to 6. […]

Thanks for your prompt reply; I am trying to do it as you say. I failed to save the merged image as a TIFF, so I use numpy.concatenate to merge the two images, like:

    stacked_array = np.concatenate((inpaint_resize, ref_image), axis=2)
    # inpaint_resize: (512, 512, 3)
    # ref_image:      (512, 512, 3)

Then I get a (512, 512, 6) numpy array as the hint. But something goes wrong:

  File "/root/autodl-tmp/ControlNet-v1-1-nightly/cldm/logger.py", line 40, in log_local
    Image.fromarray(grid).save(path)
  File "/root/miniconda3/lib/python3.8/site-packages/PIL/Image.py", line 3102, in fromarray
    raise TypeError(msg) from e
TypeError: Cannot handle this data type: (1, 1, 6), |u1

I am trying to fix this problem. May I ask if you did anything besides modifying hint_channels, or could you share the part where you save the TIFF? I would be very grateful.

MuyuenHoshino avatar Jan 04 '24 13:01 MuyuenHoshino

> May I ask if you did anything besides modifying hint_channels, or could you share the part where you save the TIFF?

It is a harmless error; you can just ignore it, since that code only records the image log during training.

            # PIL cannot save a 6-channel array, so skip logging the hint grid
            if grid.shape[2] == 6:
                continue

Add this in log_local in cldm/logger.py, just before the Image.fromarray(grid).save(path) call, to skip saving the 6-channel hint grid.
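An alternative sketch, not the fix used here: split the 6-channel grid into its two RGB hints and log them side by side (this assumes grid is an HWC uint8 array at that point and numpy is available as np in cldm/logger.py):

            # Hypothetical variant: place the two stacked RGB hints side by side
            # so both still appear in the saved image log
            if grid.shape[2] == 6:
                grid = np.concatenate([grid[:, :, :3], grid[:, :, 3:]], axis=1)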

Yuhyeong avatar Jan 05 '24 02:01 Yuhyeong

> It is a harmless error; you can just ignore it. […] Add this in log_local in cldm/logger.py to skip saving the 6-channel hint grid.

It makes sense! Thanks for your useful advice!

MuyuenHoshino avatar Jan 06 '24 10:01 MuyuenHoshino

Hey guys, I also want to use multiple controls for my thesis ControlNet. Could you add the complete YAML config and also the MyDataset class for this problem? If I get some results, I will add mine.

SamanFekri avatar Feb 03 '24 16:02 SamanFekri

I changed the model config to this:

model:
  target: cldm.scldm.ExtendedControlLDM
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    control_key: "hint"
    image_size: 64
    channels: 4 # kept at 4; only hint_channels changes
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    only_mid_control: False

    control_stage_config:
      target: cldm.cldm.ControlNet
      params:
        image_size: 32 # unused
        in_channels: 4
        hint_channels: 9 # 3 for one image, 2 * 3 for 2 image hint, 3 * 3 for 3 images
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    unet_config:
      target: cldm.cldm.ControlledUnetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

I add the hints inside my Dataset class. I changed the dataset class to load the data and concatenate the different images to each other as the source:

        source = cv2.imread(f'{self.dataset_path}/{source_filename}')
        target = cv2.imread(f'{self.dataset_path}/{target_filename}')

        # Do not forget that OpenCV reads images in BGR order.
        source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
        target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)

        # Add a Canny edge map as the second control
        detected_map = resize_image(HWC3(target), self.resolution)
        detected_map = self.apply_canny(detected_map, self.canny_low, self.canny_high)
        canny = HWC3(detected_map)

        # Down- and up-scale the original image as the third (low-res) control
        resize = cv2.resize(target, self.small_dim, interpolation=cv2.INTER_AREA)
        resize = cv2.resize(resize, self.original_dim, interpolation=cv2.INTER_AREA)

        # Concatenate the three controls along the channel axis
        source = np.concatenate((source, resize, canny), axis=2)

        # Normalize source images to [0, 1].
        source = source.astype(np.float32) / 255.0

I started from a pretrained Stable Diffusion model. I need to load its weights into the model, so I duplicate the weights of the hint block, as you can see in the following code:

import torch
from cldm.model import create_model, load_state_dict

# Model creation
model = create_model(config['model']['config_file']).cpu()

lsd = load_state_dict(resume_path, location='cpu')

# Tile the pretrained (out, 3, k, k) hint-conv weights num_hints times along
# the input-channel dimension, giving (out, 3 * num_hints, k, k)
repeated_tensor = torch.stack([torch.tensor(item).repeat(1, config['model']['num_hints'], 1, 1) for item in lsd['control_model.input_hint_block.0.weight']]).squeeze(1)

# Assign the corrected tensor to the state dictionary
lsd['control_model.input_hint_block.0.weight'] = repeated_tensor

In the above code, config['model']['num_hints'] = 3.
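One detail worth considering (my assumption, not verified in this thread): after tiling, the first hint conv receives num_hints times more input energy than the pretrained layer did, so scaling the tiled weights down before loading non-strictly may keep the initial activations comparable:

    # Hypothetical: divide by num_hints so the tiled conv initially produces
    # activations of roughly the same magnitude as the single-hint conv
    lsd['control_model.input_hint_block.0.weight'] = repeated_tensor / config['model']['num_hints']

    # strict=False tolerates any remaining key or shape mismatches
    model.load_state_dict(lsd, strict=False)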

SamanFekri avatar Feb 16 '24 16:02 SamanFekri

> I changed the model config to this: […] In the above code, config['model']['num_hints'] = 3.

Thank you for the inputs. May I know if you were able to get good results with this?

bhosalems avatar Apr 18 '24 16:04 bhosalems

> Thank you for the inputs. May I know if you were able to get good results with this?

Maybe we can't achieve the desired result. I tried segmentation map plus depth. If there are no other bugs in my experiment, then the conclusion is that the images are not OK. Were you able to train with two controls and get good results?

jiachen0212 avatar Aug 07 '24 08:08 jiachen0212