TensorFlow.jl

trouble using tf.where

Open qvdgeer opened this issue 6 years ago • 3 comments

Hi, I'm trying to get a toy example with `tf.where` working in TensorFlow.jl:

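```julia
using TensorFlow

sess = TensorFlow.Session()

A = TensorFlow.Variable([-1, -1, -1, -1, 0, -1])
B = TensorFlow.Variable([0, 0, 0, 0, 1, 0])

# Boolean mask: true where A is zero
mask = TensorFlow.equal(A, 0)

# Intended: take B where mask is true, A elsewhere
C = TensorFlow.where(mask, B, A)

run(sess, TensorFlow.global_variables_initializer())
run(sess, [A, B, mask, C])
```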

but I get this error:

```
MethodError: no method matching where(::TensorFlow.Tensor{Bool}, ::TensorFlow.Variables.Variable{Float64}, ::TensorFlow.Variables.Variable{Float64})
Closest candidates are:
  where(::Any; name) at /Users/quentinvandegeer/.julia/v0.6/TensorFlow/src/ops/imported_ops.jl:3326

Stacktrace:
 [1] include_string(::String, ::String) at ./loading.jl:515
```

Strangely, if you give `tf.where` a single boolean tensor, it returns a single value.

qvdgeer, Apr 20 '18 20:04

Use `select(mask, B, A)`.

We should probably add this as a second overload of `where`, since this keeps catching people out.

I think it comes down to our `where` being the find operation, i.e. it returns the indices of the nonzero elements. (I feel like Python's version used to do this when given one argument too, but the current docs say no.)
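To make the difference concrete, here is a plain-Julia model of the two operations on the toy data from the issue. It does not call TensorFlow.jl at all and only shows what each operation computes, not how the graph ops are implemented. It assumes Julia 1.x (on the Julia 0.6 used in this issue, `findall` was called `find`), and its indices are Julia's 1-based ones; the TensorFlow op's own index convention and output shape may differ. Under those assumptions, the three-argument selection produces the values the original code was after, while the find operation returns positions instead of values.

```julia
# Plain-Julia model of the two operations discussed above, on the issue's toy data.
# No TensorFlow.jl calls here; this only illustrates what each op computes.

A = [-1, -1, -1, -1, 0, -1]
B = [0, 0, 0, 0, 1, 0]

mask = A .== 0                # elementwise comparison, like TensorFlow.equal(A, 0)

# Three-argument form (what `select` does): pick B where mask is true, else A.
C = ifelse.(mask, B, A)

# One-argument form (the "find" behaviour): indices of the true entries.
# These are Julia's 1-based indices; the TensorFlow op may use a different convention.
idx = findall(mask)

println(C)    # [-1, -1, -1, -1, 1, -1]
println(idx)  # [5]
```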

oxinabox, Apr 21 '18 00:04

Many thanks, I will give this a try.

qvdgeer, Apr 21 '18 09:04

`select` is what the operation is called in libtensorflow.

Python TensorFlow basically has a conditional that sometimes sends `where` to the libtensorflow `where` op and other times to the libtensorflow `select` op. And Python does not expose `select` directly.
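The second overload suggested earlier in the thread could look roughly like the following, sketched on plain Julia arrays rather than TensorFlow tensors. `tfstyle_where` is a hypothetical name, not part of TensorFlow.jl, and the sketch assumes Julia 1.x. The point is only that Julia's multiple dispatch can choose the find behaviour for one argument and the select behaviour for three, similar in spirit to the conditional in Python's `tf.where` described above.

```julia
# Hypothetical sketch: one function name, two behaviours chosen by argument count.
# `tfstyle_where` is an illustrative name, not part of TensorFlow.jl.

# One argument: return the (1-based) indices of the true entries ("find").
tfstyle_where(cond::AbstractArray{Bool}) = findall(cond)

# Three arguments: elementwise selection ("select"): x where cond is true, else y.
tfstyle_where(cond::AbstractArray{Bool}, x::AbstractArray, y::AbstractArray) =
    ifelse.(cond, x, y)

mask = [false, true, false]
println(tfstyle_where(mask))                        # [2]
println(tfstyle_where(mask, [1, 2, 3], [7, 8, 9]))  # [7, 2, 9]
```

Dispatching on argument count leaves the existing one-argument meaning unchanged while adding the three-argument form that people coming from Python tend to reach for.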

oxinabox, May 30 '18 06:05