Focal-Loss-implement-on-Tensorflow

What does array_ops.where mean?

Open • DRosemei opened this issue 6 years ago • 1 comment

Thanks for sharing this. I want to know what array_ops.where means. Is it similar to tf.where? Also, I am using this loss with SSD-Tensorflow and the loss is very high. I made some changes:

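import tensorflow as tf
from tensorflow.python.ops import array_ops

# prediction_tensor: logits; target_tensor: integer class labels
sigmoid_p = tf.nn.sigmoid(prediction_tensor)
zeros = array_ops.zeros_like(sigmoid_p, dtype=sigmoid_p.dtype)

# one-hot encode the integer labels (4 classes in this setup)
t = tf.one_hot(indices=target_tensor, depth=4)
t_tensor = tf.cast(t, sigmoid_p.dtype)

# where the target is 1, keep (1 - p); everywhere else, 0
pos_p_sub = tf.where(t_tensor > zeros, t_tensor - sigmoid_p, zeros)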

Could you tell me why the loss is so high?

DRosemei • May 14 '18 03:05
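To illustrate the question, here is a small TensorFlow 2.x sketch of what the three-argument `tf.where` call in the snippet above computes, using hypothetical logits and labels (2 anchors, 4 classes). In TF 2.x the public `tf.where` broadcasts its inputs, while `tf.compat.v1.where` keeps the TF 1.x behavior of `array_ops.where`; with same-shaped inputs, as here, both select element-wise and should give identical results.

```python
import tensorflow as tf

# Hypothetical logits (2 anchors x 4 classes) and integer class labels.
prediction_tensor = tf.constant([[2.0, -1.0, 0.5, -3.0],
                                 [-0.5, 1.5, -2.0, 0.0]])
target_tensor = tf.constant([0, 1])

sigmoid_p = tf.nn.sigmoid(prediction_tensor)
zeros = tf.zeros_like(sigmoid_p)
t_tensor = tf.one_hot(target_tensor, depth=4, dtype=sigmoid_p.dtype)

# Three-argument form: take (t - p) where the condition is True, else 0.
# Since t == 1 at the true class, those entries hold 1 - p.
pos_p_sub = tf.where(t_tensor > zeros, t_tensor - sigmoid_p, zeros)

# TF 1.x-style function; same result here because all inputs share a shape.
pos_p_sub_v1 = tf.compat.v1.where(t_tensor > zeros, t_tensor - sigmoid_p, zeros)

print(pos_p_sub.numpy())
```

Only the true-class entry of each row is non-zero; every other entry comes from the `zeros` branch.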

I'm sorry for the late reply. tf.where is actually a wrapper around the array_ops.where function. I don't know the type of your target_tensor; you just need to replace your original loss function with the focal loss function.

ailias • Jun 11 '18 02:06
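The reply suggests swapping in the focal loss function, but the repository's own function is not reproduced in this thread. What follows is therefore a minimal, hypothetical sketch of a sigmoid focal loss in TensorFlow 2.x, assuming integer labels like the `target_tensor` above. The name `sigmoid_focal_loss`, the `num_classes` parameter (replacing the hard-coded `depth=4`), and the defaults `alpha=0.25`, `gamma=2.0` (the settings reported in the focal loss paper) are assumptions. It uses `tf.where` to pick the positive or negative loss term per entry rather than masking the modulating factor, which avoids `0 ** 0 == 1` leaking into the loss when `gamma == 0`.

```python
import tensorflow as tf


def sigmoid_focal_loss(prediction_tensor, target_tensor, num_classes,
                       alpha=0.25, gamma=2.0):
    """Hypothetical sketch: sigmoid focal loss summed over all entries.

    prediction_tensor: float logits, shape [..., num_classes]
    target_tensor: integer class labels, shape [...]
    """
    sigmoid_p = tf.nn.sigmoid(prediction_tensor)
    t = tf.one_hot(target_tensor, depth=num_classes, dtype=sigmoid_p.dtype)

    # Clip probabilities before taking logs to avoid log(0).
    eps = 1e-8
    log_p = tf.math.log(tf.clip_by_value(sigmoid_p, eps, 1.0))
    log_1mp = tf.math.log(tf.clip_by_value(1.0 - sigmoid_p, eps, 1.0))

    # Positive entries (t == 1): -alpha * (1 - p)^gamma * log(p)
    pos_loss = -alpha * tf.pow(1.0 - sigmoid_p, gamma) * log_p
    # Negative entries (t == 0): -(1 - alpha) * p^gamma * log(1 - p)
    neg_loss = -(1.0 - alpha) * tf.pow(sigmoid_p, gamma) * log_1mp

    # Select the matching term for each entry, then sum.
    per_entry = tf.where(t > 0.0, pos_loss, neg_loss)
    return tf.reduce_sum(per_entry)


# Usage with hypothetical logits and labels (2 anchors, 4 classes).
logits = tf.constant([[2.0, -1.0, 0.5, -3.0],
                      [-0.5, 1.5, -2.0, 0.0]])
labels = tf.constant([0, 1])
loss = sigmoid_focal_loss(logits, labels, num_classes=4)
```

This sketch returns a raw sum over every entry. The focal loss paper normalizes that sum by the number of anchors assigned to ground-truth boxes, so summing over all SSD anchors without a similar normalization can produce a large value. That may be one contributor to the high loss, though this thread does not confirm it.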