TensorFlow.jl
Trouble using tf.where
Hi, I'm trying to get a toy example with tf.where working in TensorFlow.jl:
using TensorFlow
sess = TensorFlow.Session()
A = TensorFlow.Variable([-1,-1,-1,-1,0,-1])
B = TensorFlow.Variable([0,0,0,0,1,0])
mask = TensorFlow.equal(A, 0)
C = TensorFlow.where(mask, B, A)  # this line raises the MethodError
run(sess, TensorFlow.global_variables_initializer())
run(sess, [A, B, mask, C])
but I get this error:
MethodError: no method matching where(::TensorFlow.Tensor{Bool}, ::TensorFlow.Variables.Variable{Float64}, ::TensorFlow.Variables.Variable{Float64})
Closest candidates are:
  where(::Any; name) at /Users/quentinvandegeer/.julia/v0.6/TensorFlow/src/ops/imported_ops.jl:3326
Stacktrace:
 [1] include_string(::String, ::String) at ./loading.jl:515
Strangely, if you give tf.where a single boolean tensor, it returns a single value.
Use select(mask, B, A).
We should probably add this as a second overload of where, since this keeps catching people out.
I think it comes down to our where performing the find operation, i.e. returning the indices of the nonzero elements. (I feel like Python's where used to do this too when given one argument, but the current docs say no.)
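For intuition, here is a plain-Julia analog of what the one-argument, find-style form computes. It does not use TensorFlow.jl at all and assumes Julia 1.x, where Base's `findall` replaced 0.6's `find`. TensorFlow's own op returns the coordinates of the true entries as a tensor, so this only mirrors the idea, not the exact output type. With the mask from the question there is exactly one true entry, which would explain getting a single value back.

```julia
# Plain-Julia analog (not TensorFlow.jl): find-style "where" with one argument.
A = [-1, -1, -1, -1, 0, -1]
mask = A .== 0          # BitVector: true only at position 5
idx = findall(mask)     # indices of the true entries (1-based in Julia)
println(idx)            # one true entry, so a single index comes back
```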
Many thanks, I will give this a try.
Select is what the operation is called in libtensorflow.
Python TensorFlow basically has a conditional inside tf.where, depending on whether x and y are given: sometimes the call hits the libtensorflow Where op, and other times it hits the libtensorflow Select op.
And Python does not expose select directly.
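The dispatch described above can be sketched in plain Julia with two methods: one argument means "find the true entries", three arguments mean "pick elementwise from x or y". The function name `where_like` is hypothetical and not part of TensorFlow.jl; it uses Base's `findall` and a broadcast `ifelse` on ordinary arrays (assuming Julia 1.x), only to illustrate the kind of second overload suggested earlier in the thread, not how TensorFlow.jl would implement it.

```julia
# Hypothetical helper, NOT part of TensorFlow.jl: mimics Python's tf.where dispatch.
# One argument: indices of the true entries (like libtensorflow's Where op).
where_like(cond::AbstractArray{Bool}) = findall(cond)
# Three arguments: elementwise choice between x and y (like libtensorflow's Select op).
where_like(cond::AbstractArray{Bool}, x::AbstractArray, y::AbstractArray) = ifelse.(cond, x, y)

A = [-1, -1, -1, -1, 0, -1]
B = [0, 0, 0, 0, 1, 0]
mask = A .== 0

C = where_like(mask, B, A)  # takes B where mask is true, A elsewhere
idx = where_like(mask)      # positions where mask is true
println(C)
println(idx)
```

Relying on multiple dispatch on the argument count keeps both behaviors under one name, which is the same split Python makes with its runtime conditional.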