loss_function_search
loss_function_search copied to clipboard
Some questions
# Question from the issue thread:
#   What great work! Thank you very much for publishing your code. I ran into
#   trouble when reproducing it. Could you explain the meaning of the
#   variables "p_bins" and "a"? Thank you very much.
def my_loss(x, lb, p_bins, a, sm, search_type):
    """Return a copy of ``x`` with each sample's labelled entry rescaled.

    For every sample ``i`` only the entry ``x[i, lb[i]]`` is changed; all
    other entries are copied through. With ``p`` the labelled entry, the
    replacement value is::

        p / (coef * scale * p + b)

    where ``b = 1 - coef * scale`` when ``coef <= 0`` and ``b = 1``
    otherwise. (For ``coef <= 0`` this choice of ``b`` maps ``p == 1`` to 1.)

    Args:
        x: 2-D tensor indexed as ``x[sample, class]``. Presumably softmax
            probabilities, given the comparison against ``p_bins`` -- TODO
            confirm against the caller.
        lb: Per-sample integer class indices, one per row of ``x``.
        p_bins: Ascending bin edges, read only when ``search_type`` is
            ``'local'``. Bin ``j`` is selected for a sample when its labelled
            entry is ``<= p_bins[j + 1]`` and no earlier bin matched.
            ``p_bins[0]`` is never read.
        a: Sequence of coefficients. ``'global'`` uses ``a[0]`` for every
            sample; ``'local'`` uses ``a[j]`` for samples falling in bin
            ``j``, so it needs at least ``len(p_bins) - 1`` entries.
        sm: Scalar that enters only through ``exp(sm / 3)`` (global) or
            ``exp(sm / 2)`` (local) as a multiplier on the coefficient.
            Its meaning is not visible here -- verify with the caller.
        search_type: ``'global'`` or ``'local'``.

    Returns:
        A new tensor shaped like ``x``; ``x`` itself is not modified.

    Raises:
        ValueError: If ``search_type`` is neither ``'global'`` nor
            ``'local'``.
    """
    batch_size = x.shape[0]
    # Multiplying by 1.0 yields a fresh tensor, so the writes below never
    # touch the caller's ``x``.
    new_x = 1.0 * x

    if search_type == 'global':
        scale = math.exp(sm / 3)
        coef = a[0]
        b = 1.0 - coef * scale if coef <= 0 else 1.0
        rows = torch.arange(batch_size)
        gt = x[rows, lb]
        new_x[rows, lb] = gt / (coef * scale * gt + b)
    elif search_type == 'local':
        # Loop-invariant, so computed once rather than per sample and bin.
        scale = math.exp(sm / 2)
        for i in range(batch_size):
            target = lb[i]
            gt = x[i, target]
            gt_value = gt.item()
            for j in range(len(p_bins) - 1):
                if gt_value <= p_bins[j + 1]:
                    coef = a[j]
                    b = 1.0 - coef * scale if coef <= 0 else 1.0
                    new_x[i, target] = gt / (coef * scale * gt + b)
                    # Only the first matching bin applies.
                    break
            # A sample whose labelled entry exceeds p_bins[-1] matches no
            # bin and is left unchanged.
    else:
        # ValueError is a subclass of Exception, so callers that catch
        # Exception keep working.
        raise ValueError('Unknown search type!')
    return new_x