DiffKD icon indicating copy to clipboard operation
DiffKD copied to clipboard

About the loss

Open TungyuYoung opened this issue 1 year ago • 5 comments

Dear Hunto, I tried your method on my task. I found that while the original loss (between the student's output and the ground truth) decreased, the other losses, such as the autoencoder loss and the diffusion loss, did not seem to change. What do you think the reason might be?

Best!

TungyuYoung avatar Mar 25 '24 05:03 TungyuYoung

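Hi, please check whether you have added the DiffKD module's parameters to your optimizer; parameters that are not passed to the optimizer are never updated, so their losses stay flat.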

hunto avatar Mar 27 '24 01:03 hunto

Hello, I checked my code, and it seems that I have not added the DiffKD module to the optimizer. However, I also checked the train.py you provided, and there does not seem to be such an operation there either. Have I missed something? Could you suggest a solution? Many thanks! The DiffKD code I wrote is below:

# Loss used for the autoencoder reconstruction term in DiffKD.forward.
# NOTE(review): FeatureLoss is defined elsewhere and is *called* below as
# feature_loss_func(prediction, target). If FeatureLoss is an nn.Module class
# rather than a plain function, this alias should be an instance
# (FeatureLoss()); otherwise the call constructs a module instead of
# computing a loss. F.mse_loss was the previous choice. Confirm.
feature_loss_func = FeatureLoss


class DiffKD(nn.Module):
    """Diffusion-based feature distillation head for one feature level.

    Maps the student feature to the (optionally autoencoder-compressed)
    teacher channel count, denoises it with a DDIM pipeline, and trains the
    diffusion model to predict the noise added to the teacher feature.

    Its parameters (``trans``, ``noise_adapter``, ``model``, ``proj`` and,
    when ``use_ae``, ``ae``) are only trained if they are given to the
    optimizer.
    """

    def __init__(self, student_channels, teacher_channels, kernel_size=3,
                 inference_steps=5, num_train_time_steps=1000,
                 use_ae=False,  # use autoencoder
                 ae_channels=None):
        """Build the head.

        Args:
            student_channels: channel count of the incoming student feature.
            teacher_channels: channel count of the incoming teacher feature.
            kernel_size: kernel size for NoiseAdapter and DiffusionModel.
            inference_steps: number of DDIM denoising steps used in forward.
            num_train_time_steps: number of scheduler training timesteps;
                ddim_loss samples timesteps from [0, num_train_time_steps).
            use_ae: compress the teacher feature with an AutoEncoder first.
            ae_channels: AE latent channels; defaults to teacher_channels // 2.
        """
        super().__init__()
        self.use_ae = use_ae
        self.diffusion_inference_steps = inference_steps

        # Optional autoencoder compressing the teacher feature; the diffusion
        # model then operates in the AE latent channel space.
        if use_ae:
            if ae_channels is None:
                ae_channels = teacher_channels // 2
            self.ae = AutoEncoder(teacher_channels, ae_channels)
            teacher_channels = ae_channels

        # 1x1 conv mapping the student feature to the teacher channel count.
        self.trans = nn.Conv2d(student_channels, teacher_channels, 1)

        # Diffusion model that predicts the added noise.
        self.scheduler = DDIMScheduler(num_train_time_steps=num_train_time_steps,
                                       clip_sample=False,
                                       beta_schedule="linear")
        self.noise_adapter = NoiseAdapter(teacher_channels, kernel_size)

        # Pipeline that denoises the student feature.
        self.model = DiffusionModel(in_channels=teacher_channels, kernel_size=kernel_size)
        self.pipeline = DDIMPipeline(self.model, self.scheduler, self.noise_adapter)
        self.proj = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels, 1),
            nn.BatchNorm2d(teacher_channels),
        )

    def forward(self, student_feat, teacher_feat):
        """Refine the student feature and compute the distillation losses.

        Args:
            student_feat: student feature map with ``student_channels``
                channels; dim 0 is used as the batch size (presumably
                (B, C, T, F) for spectrogram features; confirm).
            teacher_feat: teacher feature map with ``teacher_channels``
                channels.

        Returns:
            ``(refined_feature, teacher_feat, ddim_loss, rec_loss)``: the
            denoised student feature, the teacher feature (AE-encoded and
            detached when ``use_ae``), the diffusion noise-prediction loss,
            and the AE reconstruction loss (``None`` when ``use_ae`` is False).
        """
        student_feat = self.trans(student_feat)

        if self.use_ae:
            hidden_teacher_feat, rec_teacher_feat = self.ae(teacher_feat)
            rec_loss = feature_loss_func(rec_teacher_feat, teacher_feat)
            # Detach so that, within this module, the AE receives gradients
            # only from rec_loss and not from the diffusion loss.
            teacher_feat = hidden_teacher_feat.detach()
        else:
            rec_loss = None

        # Denoise the student feature.
        refined_feature = self.pipeline(
            batch_size=student_feat.shape[0],
            device=student_feat.device,
            dtype=student_feat.dtype,
            shape=student_feat.shape[1:],
            feat=student_feat,
            num_inference_steps=self.diffusion_inference_steps,
            proj=self.proj,
        )
        refined_feature = self.proj(refined_feature)

        # Train the diffusion model on the (clean) teacher feature.
        ddim_loss = self.ddim_loss(teacher_feat)

        return refined_feature, teacher_feat, ddim_loss, rec_loss

    def ddim_loss(self, gt_feat):
        """Return the MSE between predicted and true noise on a noised ``gt_feat``."""
        # randn_like matches gt_feat's dtype as well as its device (e.g. AMP).
        noise = torch.randn_like(gt_feat)
        batch_size = gt_feat.shape[0]

        # One random training timestep per sample in the batch.
        time_step = torch.randint(0, self.scheduler.num_train_time_steps, (batch_size,),
                                  device=gt_feat.device, dtype=torch.long)

        # Noise the clean feature according to each sample's timestep.
        noisy_feature = self.scheduler.add_noise(gt_feat, noise, time_step)
        noise_pred = self.model(noisy_feature, time_step)
        return F.mse_loss(noise_pred, noise)

class DiffPro(nn.Module):
    """Wrap a student/teacher pair with one DiffKD head per feature level.

    ``forward`` extracts intermediate features from both networks and feeds
    each (student, teacher) feature pair to its DiffKD head.

    Training notes:
      * The heads in ``self.DiffusionProcess`` are only trained if their
        parameters reach the optimizer. The teacher is frozen here, so
        ``torch.optim.Adam(p for p in diffpro.parameters() if p.requires_grad)``
        selects exactly the student and the heads.
      * The student module passed in does not contain the heads, so
        ``diffpro.student.state_dict()`` saves the student alone.
    """

    def __init__(self, student, teacher, device="cuda"):
        """Build the wrapper.

        Args:
            student: network being trained; set to training mode.
            teacher: reference network; set to eval mode and frozen.
            device: device the DiffKD heads are moved to (default "cuda",
                matching the previous hard-coded ``.cuda()``).
        """
        super().__init__()
        self.student = student.train()
        self.teacher = teacher.eval()
        # The teacher is only run under no_grad; freezing it also keeps its
        # parameters out of requires_grad-filtered optimizer parameter lists.
        self.teacher.requires_grad_(False)

        self.student_features = None
        self.teacher_features = None

        ae_channels = 32
        # 11 heads for the 16-channel student / 64-channel teacher levels plus
        # one head for the final 2-channel level. NOTE(review): the count must
        # equal the number of features _feature_extractor returns; confirm.
        heads = [DiffKD(16, 64, kernel_size=1, use_ae=True, ae_channels=ae_channels)
                 for _ in range(11)]
        heads.append(DiffKD(2, 2, kernel_size=1, use_ae=True, ae_channels=8))
        # ModuleList rather than Sequential: heads are indexed per level, never
        # chained. State-dict keys ("DiffusionProcess.<i>....") are unchanged.
        self.DiffusionProcess = nn.ModuleList(heads).to(device)

    def train(self, mode=True):
        """Set training mode on the wrapper while keeping the teacher in eval mode.

        ``nn.Module.train`` recurses into every child, which would otherwise
        switch the teacher back to training mode (e.g. its BatchNorm layers
        would update running statistics even during the no_grad forward).
        """
        super().train(mode)
        self.teacher.eval()
        return self

    def forward(self, noisy_specs):
        """Run every DiffKD head on its (student, teacher) feature pair.

        Returns:
            ``(refined_features, teacher_out_features, diff_loss, ae_loss)``:
            per-level lists of refined student and processed teacher features,
            plus the diffusion and AE reconstruction losses summed over levels.
        """
        with torch.no_grad():
            teacher_features = self._feature_extractor(self.teacher, noisy_specs)

        student_features = self._feature_extractor(self.student, noisy_specs)

        refined_features = []
        teacher_out_features = []
        diff_loss = 0.
        ae_loss = 0.

        # Head i is paired with feature level i; indexing (not zip) keeps the
        # IndexError if fewer features than heads are extracted.
        for i, head in enumerate(self.DiffusionProcess):
            refined, teacher_out, head_diff_loss, head_ae_loss = head(
                student_features[i], teacher_features[i])
            refined_features.append(refined)
            teacher_out_features.append(teacher_out)
            diff_loss += head_diff_loss
            # A head built with use_ae=False returns None for its AE loss.
            if head_ae_loss is not None:
                ae_loss += head_ae_loss

        return refined_features, teacher_out_features, diff_loss, ae_loss

    def _feature_extractor(self, model, noisy_specs):
        """Collect intermediate feature maps of ``model`` for ``noisy_specs``.

        Order: encoder features, dpgrnn1, dpgrnn2, then decoder features.
        ``forward`` pairs feature i with head i, so this order must match
        ``self.DiffusionProcess``.
        """
        feature_extractor = layer_feature_extraction.GTCRN_fe(model)
        try:
            features_map = feature_extractor.extract_feature_maps(noisy_specs)

            encoder_f = features_map["encoder"]
            decoder_f = features_map["decoder"]
            features = [encoder_f[i] for i in range(len(encoder_f))]
            features.append(features_map["dpgrnn1"][0][0])
            features.append(features_map["dpgrnn2"][0][0])
            features.extend(decoder_f[j] for j in range(len(decoder_f)))
        finally:
            # Run the extractor's cleanup even if extraction raises; before,
            # remove_hook was skipped on any exception.
            feature_extractor.remove_hook()

        return features

`

TungyuYoung avatar Mar 27 '24 02:03 TungyuYoung

I attached the DiffKD module to the student model (so that its parameters are optimized together with the student's) in the following code: https://github.com/hunto/image_classification_sota/blob/6cb144105fc5c2f778e51cc66e35314938f96fae/lib/models/losses/kd_loss.py#L94

Sorry about this; I'll consider a better way to achieve it.

hunto avatar Mar 27 '24 03:03 hunto

Understood. I will try to fix it. Thanks again!

TungyuYoung avatar Mar 27 '24 03:03 TungyuYoung

Dear Hunto, when I tried to save the trained student model, it seemed that the DiffKD module was saved along with it. Is there any way to avoid this? Or do I also need to use the DiffKD module during inference?

TungyuYoung avatar Mar 29 '24 10:03 TungyuYoung