PnPInversion
PnPInversion copied to clipboard
Why does the `structure_distance_metric_calculator` only use `calculate_global_ssim_loss`?
This is the code that calculates `structure_distance`:
def calculate_structure_distance(self, img_pred, img_gt, mask_pred=None, mask_gt=None, use_gpu = True):
    """Return the structure distance between a predicted and a reference image.

    Both images are converted to float32 arrays, optionally multiplied by
    their masks, turned into batched tensors on ``self.device`` (the last
    array axis is moved to the front and a batch axis of size 1 is added),
    and handed to
    ``self.structure_distance_metric_calculator.calculate_global_ssim_loss``.

    Args:
        img_pred: Predicted image; a 3-D array-like whose last axis is moved
            to the front (i.e. a channels-last layout is expected).
        img_gt: Reference image with the same shape as ``img_pred``.
        mask_pred: Optional mask multiplied into ``img_pred``. Either
            broadcastable to the image as-is, or one entry per pixel
            (shape ``img_pred.shape[:-1]``), in which case it is applied to
            every channel.
        mask_gt: Optional mask multiplied into ``img_gt``; same rules as
            ``mask_pred``.
        use_gpu: Not read by this method; the target device is always
            ``self.device``. Kept so existing callers keep working.

    Returns:
        The value returned by the metric calculator, detached and copied to
        the CPU as a NumPy array.

    Raises:
        AssertionError: If the two images do not have the same shape.
    """

    def apply_mask(img, mask):
        # Multiply ``img`` by ``mask`` (no-op when no mask was supplied).
        if mask is None:
            return img
        mask = np.array(mask).astype(np.float32)
        # A per-pixel mask has no channel axis, so plain broadcasting against
        # the channels-last image would fail; give it a trailing axis so it
        # is applied identically to every channel.
        if mask.ndim == img.ndim - 1 and mask.shape == img.shape[:-1]:
            mask = mask[..., np.newaxis]
        return img * mask

    def to_batched_tensor(img):
        # Last axis -> first axis, then prepend a batch axis of size 1.
        tensor = torch.from_numpy(np.transpose(img, axes=(2, 0, 1))).to(self.device)
        return torch.unsqueeze(tensor, 0)

    img_pred = np.array(img_pred).astype(np.float32)
    img_gt = np.array(img_gt).astype(np.float32)

    assert img_pred.shape == img_gt.shape, "Image shapes should be the same."

    # NOTE(review): pixel values are forwarded without any rescaling; confirm
    # the value range matches what the metric's own transform expects.
    img_pred = apply_mask(img_pred, mask_pred)
    img_gt = apply_mask(img_gt, mask_gt)

    img_pred = to_batched_tensor(img_pred)
    img_gt = to_batched_tensor(img_gt)

    structure_distance = self.structure_distance_metric_calculator.calculate_global_ssim_loss(img_gt, img_pred)

    # ``detach()`` is the supported replacement for the legacy ``.data``.
    return structure_distance.detach().cpu().numpy()
As we can see, it calls the calculate_global_ssim_loss method instead of forward.
This is the code of the class that `structure_distance_metric_calculator` is an instance of:
class LossG(torch.nn.Module):
    """Weighted combination of feature-based loss terms.

    Every term compares features that ``self.extractor`` (a ``VitExtractor``)
    produces for two images after both went through ``self.global_transform``
    (a resize followed by a normalisation). The per-term weights live in
    ``self.lambdas`` and are refreshed from ``cfg`` on each ``forward`` call.
    """

    def __init__(self, cfg,device):
        """Build the extractor, the image transform and the initial weights.

        Args:
            cfg: Mapping read for ``dino_model_name``,
                ``dino_global_patch_size`` and ``lambda_global_cls`` here, and
                for the scheduling / weight keys in ``update_lambda_config``.
            device: Device given to the extractor and used when moving
                images in ``calculate_crop_cls_loss``.
        """
        super().__init__()

        self.cfg = cfg
        self.device = device
        self.extractor = VitExtractor(model_name=cfg['dino_model_name'], device=device)

        normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        resize = Resize(cfg['dino_global_patch_size'], max_size=480)
        self.global_transform = transforms.Compose([resize, normalize])

        # Only the global-cls weight is active from the start; the remaining
        # weights are switched on later by ``update_lambda_config``.
        self.lambdas = {
            'lambda_global_cls': cfg['lambda_global_cls'],
            'lambda_global_ssim': 0,
            'lambda_entire_ssim': 0,
            'lambda_entire_cls': 0,
            'lambda_global_identity': 0,
        }

    def update_lambda_config(self, step):
        """Refresh ``self.lambdas`` for the given training ``step``.

        The global-ssim and global-identity weights are copied from ``cfg``
        when ``step`` equals ``cfg['cls_warmup']`` and are not touched on any
        other step. The two "entire" weights are taken from ``cfg`` on steps
        divisible by ``cfg['entire_A_every']`` and set to 0 otherwise.
        """
        cfg = self.cfg
        weights = self.lambdas

        if step == cfg['cls_warmup']:
            weights['lambda_global_ssim'] = cfg['lambda_global_ssim']
            weights['lambda_global_identity'] = cfg['lambda_global_identity']

        entire_step = step % cfg['entire_A_every'] == 0
        weights['lambda_entire_ssim'] = cfg['lambda_entire_ssim'] if entire_step else 0
        weights['lambda_entire_cls'] = cfg['lambda_entire_cls'] if entire_step else 0

    def forward(self, outputs, inputs):
        """Compute every currently active loss term and their weighted sum.

        Args:
            outputs: Mapping read for ``x_global``, ``x_entire`` and
                ``y_global`` (only the keys of active terms are accessed).
            inputs: Mapping read for ``step`` and, for active terms,
                ``A_global``, ``A`` and ``B_global``.

        Returns:
            A dict with one entry per active term plus ``'loss'``, the
            weighted total (the integer 0 when no term is active).
        """
        self.update_lambda_config(inputs['step'])

        # (weight name, result key, loss function, key in outputs, key in inputs)
        # The order fixes both the insertion order of the result dict and the
        # order in which the terms are accumulated.
        terms = (
            ('lambda_global_ssim', 'loss_global_ssim', self.calculate_global_ssim_loss, 'x_global', 'A_global'),
            ('lambda_entire_ssim', 'loss_entire_ssim', self.calculate_global_ssim_loss, 'x_entire', 'A'),
            ('lambda_entire_cls', 'loss_entire_cls', self.calculate_crop_cls_loss, 'x_entire', 'B_global'),
            ('lambda_global_cls', 'loss_global_cls', self.calculate_crop_cls_loss, 'x_global', 'B_global'),
            ('lambda_global_identity', 'loss_global_id_B', self.calculate_global_id_loss, 'y_global', 'B_global'),
        )

        losses = {}
        total = 0
        for weight_name, loss_name, loss_fn, out_key, in_key in terms:
            if self.lambdas[weight_name] > 0:
                losses[loss_name] = loss_fn(outputs[out_key], inputs[in_key])
                total += losses[loss_name] * self.lambdas[weight_name]

        losses['loss'] = total
        return losses

    def calculate_global_ssim_loss(self, outputs, inputs):
        """Sum, over paired images, the MSE between key self-similarities.

        The self-similarity of each image from ``inputs`` is computed under
        ``torch.no_grad()`` and serves as the target; the one for the paired
        image from ``outputs`` is computed with gradients enabled. Both use
        layer 11 of the extractor.
        """
        total = 0.0
        # One pair at a time to avoid memory limitations.
        for reference, candidate in zip(inputs, outputs):
            reference = self.global_transform(reference)
            candidate = self.global_transform(candidate)
            with torch.no_grad():
                reference_self_sim = self.extractor.get_keys_self_sim_from_input(
                    reference.unsqueeze(0), layer_num=11
                )
            candidate_self_sim = self.extractor.get_keys_self_sim_from_input(
                candidate.unsqueeze(0), layer_num=11
            )
            total += F.mse_loss(candidate_self_sim, reference_self_sim)
        return total

    def calculate_crop_cls_loss(self, outputs, inputs):
        """Sum, over paired images, the MSE between ``[0, 0, :]`` feature slices.

        The slice is taken from the last element returned by
        ``extractor.get_feature_from_input`` (presumably the CLS token, going
        by the original variable names). The slice for the image from
        ``inputs`` is computed under ``torch.no_grad()`` and is the target.
        """
        total = 0.0
        # One pair at a time to avoid memory limitations.
        for candidate, reference in zip(outputs, inputs):
            candidate = self.global_transform(candidate).unsqueeze(0).to(self.device)
            reference = self.global_transform(reference).unsqueeze(0).to(self.device)
            candidate_features = self.extractor.get_feature_from_input(candidate)
            candidate_token = candidate_features[-1][0, 0, :]
            with torch.no_grad():
                reference_features = self.extractor.get_feature_from_input(reference)
                reference_token = reference_features[-1][0, 0, :]
            total += F.mse_loss(candidate_token, reference_token)
        return total

    def calculate_global_id_loss(self, outputs, inputs):
        """Sum, over paired images, the MSE between layer-11 keys.

        Keys of each image from ``inputs`` are computed under
        ``torch.no_grad()``; keys of the paired image from ``outputs`` are
        computed with gradients enabled.
        """
        total = 0.0
        for reference, candidate in zip(inputs, outputs):
            reference = self.global_transform(reference)
            candidate = self.global_transform(candidate)
            with torch.no_grad():
                reference_keys = self.extractor.get_keys_from_input(reference.unsqueeze(0), 11)
            candidate_keys = self.extractor.get_keys_from_input(candidate.unsqueeze(0), 11)
            total += F.mse_loss(reference_keys, candidate_keys)
        return total