
Problem about the loss

Open FeiiYin opened this issue 3 years ago • 0 comments

Hi, Yang:

I've read your LeBA work and I'm very interested in it. Recently, I've been trying to reproduce the part that uses the forward loss and backward loss to narrow the gap between the two models, but I'm confused about the code. I hope you can help clarify it. Thank you very much.

I followed the code in this repository and tried to calculate the backward loss as follows:

images = images.detach().clone().requires_grad_(True)  # requires_grad is needed for autograd.grad below
adv_images = images + diff

surrogate_logits = surrogate_model(images)
surrogate_loss = nn.CrossEntropyLoss(reduction='none')(surrogate_logits, labels)

grad = torch.autograd.grad(surrogate_loss.sum(), images, create_graph=True)[0]
s_loss = (diff.detach() * grad).view([images.shape[0], -1]).sum(dim=1)  # per-sample, shape [B]

target_adv_logits = target_model(adv_images)
target_adv_loss = nn.CrossEntropyLoss(reduction='none')(target_adv_logits, labels)

target_ori_logits = target_model(images)
target_ori_loss = nn.CrossEntropyLoss(reduction='none')(target_ori_logits, labels)
d_loss = torch.log(target_adv_loss) - torch.log(target_ori_loss)  # per-sample, shape [B]

backward_loss = nn.MSELoss()(s_loss / lamda, d_loss.detach())
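For reference, my snippet above can be run end-to-end with tiny stand-in models; the linear networks, input shapes, `diff`, and `lamda` values below are placeholders I made up for a self-contained test, not the actual models or hyperparameters from the paper:

```python
# Minimal runnable sketch of the backward-loss computation.
# Stand-in models and data; the real LeBA models are ImageNet CNNs.
import torch
import torch.nn as nn

torch.manual_seed(0)
surrogate_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
target_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))

images = torch.rand(4, 3, 8, 8)
labels = torch.randint(0, 10, (4,))
diff = 0.05 * torch.randn_like(images)  # placeholder perturbation
lamda = 0.1                             # placeholder scaling constant

images = images.detach().clone().requires_grad_(True)  # needed by autograd.grad
adv_images = images + diff

surrogate_loss = nn.CrossEntropyLoss(reduction='none')(surrogate_model(images), labels)
# create_graph=True keeps the graph so backward_loss can backprop into the surrogate
grad = torch.autograd.grad(surrogate_loss.sum(), images, create_graph=True)[0]
s_loss = (diff.detach() * grad).view(images.shape[0], -1).sum(dim=1)  # per-sample

with torch.no_grad():  # the target is a black box; no gradients flow through it
    target_adv_loss = nn.CrossEntropyLoss(reduction='none')(target_model(adv_images), labels)
    target_ori_loss = nn.CrossEntropyLoss(reduction='none')(target_model(images), labels)
d_loss = torch.log(target_adv_loss) - torch.log(target_ori_loss)  # per-sample

backward_loss = nn.MSELoss()(s_loss / lamda, d_loss.detach())
backward_loss.backward()  # second-order gradients reach the surrogate's parameters
```

Running this confirms that, mechanically, the gradient of `backward_loss` does reach the surrogate's weights; my question is why training with it still widens the gap.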

However, with the above implementation, the gap between the outputs of the surrogate model and the target model gradually widens uncontrollably. I measured the gap via torch.nn.MSELoss(reduction='sum')(target_model(images), surrogate_model(images)).

What's more, I followed the code in the repository and tried to calculate the forward loss as follows:

surrogate_logits = surrogate_model(images)
surrogate_prob = torch.nn.functional.softmax(surrogate_logits, dim=1)
s_score = surrogate_prob.gather(1, labels.reshape([-1, 1]))

target_logits = target_model(images)
target_prob = torch.nn.functional.softmax(target_logits, dim=1)
target_score = target_prob.gather(1, labels.reshape([-1, 1]))

forward_loss = nn.MSELoss()(s_score, target_score.detach())
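Again for reference, here is the same forward-loss computation as a self-contained sketch with placeholder models and random data of my own choosing (not the paper's setup):

```python
# Minimal runnable sketch of the forward loss: match the surrogate's
# softmax score on the true label to the target's (detached) score.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
surrogate_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
target_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))

images = torch.rand(4, 3, 8, 8)
labels = torch.randint(0, 10, (4,))

# Surrogate score on the ground-truth class, shape [B, 1]
s_score = F.softmax(surrogate_model(images), dim=1).gather(1, labels.reshape(-1, 1))

with torch.no_grad():  # only the target's output probabilities are observed
    target_score = F.softmax(target_model(images), dim=1).gather(1, labels.reshape(-1, 1))

forward_loss = nn.MSELoss()(s_score, target_score)
forward_loss.backward()  # trains only the surrogate toward the target's scores
```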

With the forward loss, the gap does not show a monotonic downward trend either. I would like to ask which part of my understanding is wrong. :(

Yours.

Fei

FeiiYin avatar Jan 08 '21 02:01 FeiiYin