
v2: `dynamic_unroll` using `lstm_with_recurrent_dropout` doesn't work

Open · meta-inf opened this issue 4 years ago · 4 comments

The following code doesn't work:

import tensorflow as tf
import sonnet as snt

class A:

    def __init__(self):
        # lstm_with_recurrent_dropout returns a (train, test) pair of cores.
        self.train_core, self.test_core = snt.lstm_with_recurrent_dropout(
            1, dropout=0.5)

    @tf.function
    def forward(self, inp):
        # Unroll the training core over the leading (time) dimension of inp.
        return snt.dynamic_unroll(self.train_core, inp,
                                  self.train_core.initial_state(5))

a = A()
# Shape [num_steps=20, batch_size=5, features=1].
inp = tf.tile(tf.linspace(-1., 1., 20)[:, None, None], [1, 5, 1])
print(a.forward(inp))

prompting:

ValueError: in converted code:

    test.py:13 forward  *
        return snt.dynamic_unroll(self.train_core, inp,
    /home/ziyu/sonnet/sonnet/src/utils.py:310 smart_autograph_wrapper  *
        return f_autograph(*args, **kwargs)
    /tmp/tmp8pmfpnpt.py:79 tf__dynamic_unroll
        state, output_tas, outputs = ag__.for_stmt(ag__.converted_call(tf.range, (1, num_steps), None, fscope), None, loop_body, get_state, set_state, (state, output_tas, outputs), ('state', 'output_tas', 'outputs'), ())
    /home/ziyu/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:315 for_stmt
        composite_symbol_names)
    /home/ziyu/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:478 _tf_range_for_stmt
        opts=opts,
    /home/ziyu/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/autograph/operators/control_flow.py:794 _tf_while_stmt
        aug_init_vars, **opts)
    /home/ziyu/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/ops/control_flow_ops.py:2675 while_loop
        back_prop=back_prop)
    /home/ziyu/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/ops/while_v2.py:77 while_loop
        expand_composites=True)
    /home/ziyu/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/util/nest.py:568 map_structure
        structure[0], [func(*x) for x in entries],
    /home/ziyu/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/util/nest.py:568 <listcomp>
        structure[0], [func(*x) for x in entries],
    /home/ziyu/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/framework/indexed_slices.py:318 internal_convert_to_tensor_or_indexed_slices
        return ops.convert_to_tensor(value, dtype=dtype, name=name, as_ref=as_ref)
    /home/ziyu/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:1314 convert_to_tensor
        ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
    /home/ziyu/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/framework/constant_op.py:317 _constant_tensor_conversion_function
        return constant(v, dtype=dtype, name=name)
    /home/ziyu/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/framework/constant_op.py:258 constant
        allow_broadcast=True)
    /home/ziyu/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/framework/constant_op.py:296 _constant_impl
        allow_broadcast=allow_broadcast))
    /home/ziyu/anaconda3/envs/tf2/lib/python3.6/site-packages/tensorflow_core/python/framework/tensor_util.py:439 make_tensor_proto
        raise ValueError("None values not supported.")

    ValueError: None values not supported.

It seems this is because a None value is being stored in the initial state of a while loop. Changing the None value on this line to 0. appears to fix it.
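
For context, the constraint being violated is that every tf.while_loop loop variable must be convertible to a tensor. A minimal, Sonnet-free sketch (assuming TF 2.1, as in the traceback above) hits the same error:

import tensorflow as tf

# Any None inside the carried loop state fails tensor conversion, just as
# in the traceback above.
i0 = tf.constant(0)
state0 = (tf.constant(0.0), None)  # the None mimics the missing cell mask

tf.while_loop(
    cond=lambda i, s: i < 3,
    body=lambda i, s: (i + 1, s),
    loop_vars=(i0, state0),
)
# ValueError: None values not supported.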

I'm using tensorflow-gpu 2.1.0 and Python 3.6. For Sonnet I've tried both 2.0.0b0 and HEAD (1f5c0c241).
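
In the meantime, a workaround that appears to side-step the problem is to unroll with a plain Python loop, which AutoGraph unrolls at trace time instead of lowering to tf.while_loop, so the None never has to enter loop state. This is an untested sketch (forward_static is just an illustrative name, not a Sonnet API):

import tensorflow as tf
import sonnet as snt

train_core, test_core = snt.lstm_with_recurrent_dropout(1, dropout=0.5)

@tf.function
def forward_static(inp, batch_size=5):
    # Looping over a Python range (not a tensor) is unrolled while tracing,
    # so the wrapper state, which contains a None dropout mask, never
    # passes through tf.while_loop.
    state = train_core.initial_state(batch_size)
    outputs = []
    for t in range(inp.shape[0]):
        out, state = train_core(inp[t], state)
        outputs.append(out)
    return tf.stack(outputs), state

inp = tf.tile(tf.linspace(-1., 1., 20)[:, None, None], [1, 5, 1])
outputs, final_state = forward_static(inp)

snt.static_unroll should avoid the while loop for the same reason, though I haven't verified it.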

meta-inf · Feb 22 '20

CC @superbobry

tomhennigan · Feb 22 '20

It seems this is because a None value is being stored in the initial state of a while loop. Changing the None value on this line to 0. appears to fix it.

The line mentioned: rate = LSTMState(hidden=dropout, cell=None)

@meta-inf do you think that extracting the cell / initial state into a parameter of lstm_with_recurrent_dropout could fix the issue?

Your usage would then look something like: self.train_core, test_core = snt.lstm_with_recurrent_dropout(1, dropout=0.5, initial_state=0)

SergiiVolodko · Jun 29 '20

@SergiiVolodko I'm not sure what you mean by initial_state: the LSTMState created in lstm_with_recurrent_dropout stores the dropout rate and is not an initial state. The initial state of _RecurrentDropoutWrapper consists of both the initial state of the base RNN core and a dropout mask (see here), and it doesn't seem like we should change the entire state to 0.
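
To make that concrete, here is a runnable toy paraphrase of the mechanism; this is not the actual Sonnet source, and plain dicts stand in for LSTMState and the wrapper internals:

import tensorflow as tf

# cf. rate = LSTMState(hidden=dropout, cell=None): a None rate means
# "no dropout on this substate", so the sampled mask is also None.
rates = dict(hidden=0.5, cell=None)
core_state = dict(hidden=tf.ones([5, 1]), cell=tf.ones([5, 1]))

def maybe_mask(rate, s):
    return None if rate is None else tf.nn.dropout(tf.ones_like(s), rate=rate)

masks = tf.nest.map_structure(maybe_mask, rates, core_state)
state = (core_state, masks)
# masks["cell"] is None, and this whole structure is the unroll state, so
# dynamic_unroll's tf.while_loop rejects it with "None values not supported".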

I think the easiest fix is to set that cell parameter to 0.: it's a placeholder and is never used. But any change that makes my original code example work would be fine by me.
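
For reference, the change being discussed is a one-liner against the line quoted above (a sketch, not a verbatim patch):

# Sketch of the suggested change inside lstm_with_recurrent_dropout:
# with a rate of 0. the sampled cell mask is all ones (i.e. no dropout),
# which is tensor-convertible, so the unroll state no longer contains None.
rate = LSTMState(hidden=dropout, cell=0.)  # was: cell=None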

meta-inf · Jul 02 '20

Thanks, @meta-inf! Indeed, the parameter name I proposed is wrong. I will dig a bit deeper into the topic to get a better understanding of RNNs, LSTMs, and the code (any useful links are welcome). The approach I'm proposing is still fairly intuitive, though: turn a currently hardcoded value into a parameter to make the function more flexible. If somebody can confirm that extracting this parameter makes sense for other use cases, I'm happy to create a pull request with the improvement. Of course, you are more than welcome to suggest a more suitable name for the parameter :)

SergiiVolodko · Jul 05 '20