Focal-Loss-implement-on-Tensorflow
Because 0**0 = 1, the formula requires gamma != 0
Adding handling for the gamma = 0 case:
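To see why gamma = 0 breaks the unbranched formula, here is a minimal pure-Python sketch that computes one loss entry at a time. It assumes the original implementation applied the gamma != 0 formula for every gamma. The helpers `_log_p_terms`, `original_entry_loss` and `fixed_entry_loss` are hypothetical and only mirror the TensorFlow code.

```python
import math

# Hypothetical scalar helpers that mirror the TensorFlow code, one entry at a time.
def _log_p_terms(logit):
    p = 1.0 / (1.0 + math.exp(-logit))
    log_p = math.log(min(max(p, 1e-8), 1.0))
    log_1mp = math.log(min(max(1.0 - p, 1e-8), 1.0))
    return p, log_p, log_1mp

def original_entry_loss(logit, target, alpha=0.5, gamma=0.0):
    # Assumed original: the gamma != 0 formula is used for every gamma.
    p, log_p, log_1mp = _log_p_terms(logit)
    pos_p_sub = target - p if target > 0 else 0.0  # masked to 0 for negatives
    neg_p_sub = 0.0 if target > 0 else p           # masked to 0 for positives
    return (-alpha * pos_p_sub ** gamma * log_p
            - (1 - alpha) * neg_p_sub ** gamma * log_1mp)

def fixed_entry_loss(logit, target, alpha=0.5):
    # gamma = 0 branch: the labels themselves act as masks.
    _, log_p, log_1mp = _log_p_terms(logit)
    return -alpha * target * log_p - (1 - alpha) * (1 - target) * log_1mp

print(0.0 ** 0)                       # 1.0, so the zero mask no longer removes the term
print(original_entry_loss(0.0, 0.0))  # log(2), about 0.693
print(fixed_entry_loss(0.0, 0.0))     # 0.5 * log(2), about 0.347
```

For a negative entry (target 0) with logit 0, the unbranched formula also picks up the positive term, because `pos_p_sub ** 0` is 1 even where `pos_p_sub` was masked to 0. With gamma = 2 the mask works again, since 0 ** 2 = 0.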
```python
import tensorflow as tf


def focal_loss_sigmoid(prediction_tensor, target_tensor, weights=None, alpha=0.5, gamma=0):
    """Per-entry sigmoid focal loss for float32 logits and 0/1 targets.
    `weights` is accepted for API compatibility but is not used."""
    target_tensor = tf.cast(target_tensor, tf.float32)
    sigmoid_p = tf.nn.sigmoid(prediction_tensor)
    zeros = tf.zeros_like(sigmoid_p, dtype=sigmoid_p.dtype)
    # Positive entries keep (target - p); negative entries are masked to 0.
    pos_p_sub = tf.where(target_tensor > zeros, target_tensor - sigmoid_p, zeros)
    # Negative entries keep p; positive entries are masked to 0.
    neg_p_sub = tf.where(target_tensor > zeros, zeros, sigmoid_p)
    log_p = tf.math.log(tf.clip_by_value(sigmoid_p, 1e-8, 1.0))
    log_1mp = tf.math.log(tf.clip_by_value(1.0 - sigmoid_p, 1e-8, 1.0))
    if gamma != 0:
        per_entry_cross_ent = - alpha * (pos_p_sub ** gamma) * log_p \
                              - (1 - alpha) * (neg_p_sub ** gamma) * log_1mp
    else:
        # 0 ** 0 == 1, so the zero masks above would not switch off the wrong
        # term; use the labels themselves as masks instead.
        per_entry_cross_ent = - alpha * target_tensor * log_p \
                              - (1 - alpha) * (1 - target_tensor) * log_1mp
    return per_entry_cross_ent
```
With gamma = 0 this is just sigmoid cross-entropy, weighted by alpha for positives and 1 - alpha for negatives.
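A quick numerical check of that claim, as a pure-Python sketch: the scalar helper `focal_loss_gamma0` is hypothetical and mirrors the gamma = 0 branch, while `sigmoid_ce_with_logits` uses the numerically stable formula documented for `tf.nn.sigmoid_cross_entropy_with_logits`. With binary labels, and as long as the 1e-8 clip is not hit, the two agree up to the alpha weight, so alpha = 0.5 gives exactly half the standard loss.

```python
import math

def focal_loss_gamma0(logit, target, alpha=0.5):
    # Hypothetical scalar version of the gamma = 0 branch.
    p = 1.0 / (1.0 + math.exp(-logit))
    log_p = math.log(min(max(p, 1e-8), 1.0))
    log_1mp = math.log(min(max(1.0 - p, 1e-8), 1.0))
    return -alpha * target * log_p - (1 - alpha) * (1 - target) * log_1mp

def sigmoid_ce_with_logits(logit, target):
    # Stable form documented for tf.nn.sigmoid_cross_entropy_with_logits:
    # max(x, 0) - x * z + log(1 + exp(-|x|))
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

for alpha in (0.25, 0.5):
    for logit in (-3.0, -0.5, 0.0, 1.2, 4.0):
        for target in (0.0, 1.0):
            # alpha weights positive entries, (1 - alpha) weights negative ones
            weight = alpha * target + (1 - alpha) * (1 - target)
            assert math.isclose(focal_loss_gamma0(logit, target, alpha),
                                weight * sigmoid_ce_with_logits(logit, target),
                                rel_tol=1e-9)
```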
Agreed. I also found that when gamma=0, it does not correctly compute the loss.