Focal-Loss-implement-on-Tensorflow
What does array_ops.where mean?
Thanks for sharing. I would like to know what array_ops.where means. Is it similar to tf.where? Also, I am using this loss with SSD-Tensorflow, and the loss is very high. I made some changes:
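import tensorflow as tf
from tensorflow.python.ops import array_ops
sigmoid_p = tf.nn.sigmoid(prediction_tensor)
zeros = array_ops.zeros_like(sigmoid_p, dtype=sigmoid_p.dtype)
# One-hot encode the integer class labels (4 classes here)
t = tf.one_hot(indices=target_tensor, depth=4)
t_tensor = tf.cast(t, sigmoid_p.dtype)
# Keep (target - p) where the target is positive, zero elsewhere
pos_p_sub = tf.where(t_tensor > zeros, t_tensor - sigmoid_p, zeros)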
Could you tell me why?
I'm sorry for the late reply. tf.where is actually a wrapper for the array_ops.where function. I don't know the type of your target_tensor; you just need to replace your original loss function with the focal loss function.
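To make the where selection concrete, here is a minimal sketch of a sigmoid focal loss written in NumPy rather than TensorFlow, so it can be run without a graph or session. np.where(cond, x, y) picks elementwise from x where cond is true and from y otherwise, which is the same selection the three-argument tf.where performs in the snippet above. The function name focal_loss_sketch, the defaults alpha=0.25 and gamma=2.0, the eps clipping value, and the assumption that labels are integer class indices are all illustrative choices, not taken from this repository.

```python
import numpy as np

def focal_loss_sketch(logits, labels, num_classes, alpha=0.25, gamma=2.0):
    """Sigmoid focal loss summed over all entries (illustrative sketch).

    logits: float array of shape (N, num_classes), raw scores.
    labels: int array of shape (N,), class indices (assumed, not one-hot).
    """
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.int64)

    p = 1.0 / (1.0 + np.exp(-logits))        # sigmoid probabilities
    targets = np.eye(num_classes)[labels]    # one-hot encoding, shape (N, C)
    zeros = np.zeros_like(p)

    # Positive entries keep (1 - p); negative entries are zeroed out.
    pos_p_sub = np.where(targets > zeros, targets - p, zeros)
    # Negative entries keep p; positive entries are zeroed out.
    neg_p_sub = np.where(targets > zeros, zeros, p)

    eps = 1e-8  # hypothetical clipping value to avoid log(0)
    per_entry = (
        -alpha * pos_p_sub ** gamma * np.log(np.clip(p, eps, 1.0))
        - (1.0 - alpha) * neg_p_sub ** gamma * np.log(np.clip(1.0 - p, eps, 1.0))
    )
    return per_entry.sum()

# Usage: two samples, four classes, uninformative logits.
loss = focal_loss_sketch(np.zeros((2, 4)), [0, 3], num_classes=4)
```

Because this sketch sums over every entry, its value grows with the number of predictions. With the many anchors SSD produces, a sum can look very large on its own, so it may be worth checking how the loss is normalized before comparing it with the original loss.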