TypeError for td.OneOf
I am using td.OneOf to set up the terminal condition for a recursion: I want to run one of two different blocks depending on whether the input list is empty or not.
td.OneOf(lambda pair: pair[0] == [], (add_metrics(is_train, is_empty=False), add_metrics(is_train, is_empty=True)))
But I get "TypeError: bad output type PyObjectType for <td.Composition.input 'metrics'>, expected TupleType". It looks as if a PyObjectType is being fed into add_metrics, but I am pretty sure that the output of the block before td.OneOf is a TupleType.
Thanks!
Hi, I think I solved the problem.
The basic usage of td.OneOf is td.OneOf(key_fn, case_blocks), where key_fn is either a Python function or a block. If key_fn is a Python function, Fold wraps it in a td.InputTransform block, whose input and output types are both PyObjectType.
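To make the dispatch rule concrete, here is a plain-Python model of what td.OneOf does at evaluation time. It is a hypothetical toy (`one_of`, `step`, `pick` are illustration names, not Fold API) that ignores Fold's type checking and TensorFlow entirely; it only mirrors the idea that the key returned by key_fn selects an entry of case_blocks, so with a tuple of two cases a False key picks index 0 and a True key picks index 1.

```python
def one_of(key_fn, case_blocks):
    """Toy stand-in for td.OneOf dispatch (hypothetical; not Fold's implementation)."""
    # A list or tuple of cases behaves like a dict keyed 0, 1, 2, ...
    if isinstance(case_blocks, (list, tuple)):
        cases = dict(enumerate(case_blocks))
    else:
        cases = dict(case_blocks)

    def run(x):
        key = key_fn(x)        # a bool key works because True == 1 and False == 0
        return cases[key](x)   # the whole input is passed to the selected case
    return run


# Terminal condition in the style of the question: input is (remaining_list, accumulator).
step = one_of(
    key_fn=lambda pair: pair[0] == [],
    case_blocks=(
        # key False -> index 0: list not empty, consume the head and recurse
        lambda pair: step((pair[0][1:], pair[1] + pair[0][0])),
        # key True -> index 1: list empty, stop and return the accumulator
        lambda pair: pair[1],
    ),
)

# Same selection rule as the a/b example in this answer: '1' -> multiply, otherwise add.
pick = one_of(
    key_fn=lambda x: x[0] == '1',
    case_blocks=(lambda x: x[1][0] + x[1][1], lambda x: x[1][0] * x[1][1]),
)

print(step(([3, 4, 5], 0)))   # 12
print(pick(('1', (3, 4))))    # 12
print(pick(('2', (3, 4))))    # 7
```

The model deliberately says nothing about types; in real Fold the interesting failure happens before evaluation, when the block types are checked.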
Remember that the input types of key_fn, case_blocks, and td.AllOf must all be the same, because the output of the block before td.OneOf is also fed into key_fn as its input. When td.OneOf is constructed, Fold propagates the input type from the key_fn block to td.OneOf and then on to case_blocks. See line 1628 of blocks.py.
So here is the contradiction: your case_blocks expect a TupleType input, but the input type propagated from key_fn is PyObjectType.
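The conflict can be reproduced with a toy model of this type propagation. Everything below (`PY_OBJECT`, `TUPLE`, `unify`, `oneof_input_type`, the case names) is hypothetical and written only for illustration; Fold's real type objects and unification rules are richer (some blocks accept more than one input type), so this sketch only captures the rule described above: td.OneOf takes its input type from key_fn and then requires every case block to accept that same type.

```python
PY_OBJECT = "PyObjectType"
TUPLE = "TupleType"


class BlockTypeError(TypeError):
    """Raised when two blocks disagree about an input type (toy model)."""


def unify(expected, actual, where):
    # Toy rule: the types must match exactly (a simplification of Fold's checks).
    if expected != actual:
        raise BlockTypeError(
            f"bad input type {actual} for {where}, expected {expected}")
    return actual


def oneof_input_type(key_fn_input_type, case_input_types):
    """Hypothetical model: OneOf adopts key_fn's input type and pushes it
    to every case block, each of which must accept it."""
    input_type = key_fn_input_type
    for name, expected in case_input_types.items():
        unify(expected, input_type, name)
    return input_type


# Both case blocks want a tuple, like add_metrics in the question.
cases = {"case 0": TUPLE, "case 1": TUPLE}

# 1) Plain Python key_fn: treated like an InputTransform, so its input is PyObjectType.
try:
    oneof_input_type(PY_OBJECT, cases)
except BlockTypeError as err:
    print("error:", err)  # error: bad input type PyObjectType for case 0, expected TupleType

# 2) key_fn = td.GetItem(0) >> td.InputTransform(...): its input is the tuple itself.
print(oneof_input_type(TUPLE, cases))  # TupleType
```

The design point the model illustrates is that the fix belongs in key_fn, not in the case blocks: once key_fn itself accepts the tuple, the propagated type matches what the cases expect.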
See the following example.
import tensorflow as tf
import tensorflow_fold as td

a = ['1', ['3', '4']]
b = ['2', ['3', '4']]
block1 = (td.Scalar('int32'), td.Scalar('int32'))
block2 = td.Function(tf.add)
block3 = td.Function(tf.multiply)
# key True -> case_blocks[1] (multiply), key False -> case_blocks[0] (add)
oneof = (td.Identity(), block1) >> td.OneOf(
    key_fn=lambda x: x == '1',
    case_blocks=(td.GetItem(1) >> block2, td.GetItem(1) >> block3))
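This raises "TypeError: bad input type PyObjectType for <td.Function tf_fn='add'>, expected TupleType or TensorType". It can be fixed by turning key_fn into a block that first extracts the key with td.GetItem(0), so that key_fn takes the same tuple input as the case blocks: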
oneof = (td.Identity(), block1) >> td.OneOf(
    key_fn=(td.GetItem(0) >> td.InputTransform(lambda x: x == '1')),
    case_blocks=(td.GetItem(1) >> block2, td.GetItem(1) >> block3))
oneof.eval(a)
=> array(12, dtype=int32)
oneof.eval(b)
=> array(7, dtype=int32)