
Support recurrent with no states

Open • Beronx86 opened this issue 8 years ago • 5 comments

The recurrent wrapper does not support loops with no states, but such loops can be useful, so I modified the code.

Fixes #1112

Beronx86, Jun 08 '16 06:06

I'll let someone more familiar with recurrent do the review, but I know they will ask you to add a test. :smile:

dwf, Jun 08 '16 06:06

OK, I'll write a test case.

Beronx86, Jun 08 '16 06:06

In the original code, the recurrent decorator requires an initial_states method whether or not there are recurrent states. If the list of recurrent states is left empty, an error occurs when the wrapper calls initial_states.
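The failure mode and the fix idea can be illustrated with a toy decorator in plain Python and NumPy. This is a hypothetical sketch, not the Blocks implementation: `toy_recurrent`, `NoStates` and `RunningSum` are invented names, and the only idea carried over from the discussion is that the wrapper should ask the brick for initial states only when the states list is non-empty. As an assumption of this sketch, a step function returns its updated states first, followed by any other outputs.

```python
import numpy as np


def toy_recurrent(states, outputs):
    """Hypothetical stand-in for a recurrent decorator (not Blocks' code)."""
    def decorator(step):
        def wrapper(brick, inputs):
            # inputs has shape (time, batch, features)
            n_steps, batch_size = inputs.shape[0], inputs.shape[1]
            if states:
                # Only bricks that declare states must provide initial_states.
                current = list(brick.initial_states(batch_size))
            else:
                # Stateless loop: nothing to initialise, nothing requested.
                current = []
            per_step = []
            for t in range(n_steps):
                result = step(brick, inputs[t], *current)
                if not isinstance(result, tuple):
                    result = (result,)
                # Assumed convention: the first len(states) results are the
                # updated states carried into the next step.
                current = list(result[:len(states)])
                per_step.append(result)
            # Stack every output over time: one (time, batch, ...) array each.
            stacked = [np.stack(group) for group in zip(*per_step)]
            return dict(zip(outputs, stacked))
        return wrapper
    return decorator


class NoStates(object):
    # No initial_states method at all; this works because states=[].
    @toy_recurrent(states=[], outputs=['outputs', 'outputs_2'])
    def apply2(self, x_t):
        return x_t * 10, np.square(x_t)


class RunningSum(object):
    def __init__(self, dim):
        self.dim = dim

    def initial_states(self, batch_size):
        return [np.zeros((batch_size, self.dim))]

    @toy_recurrent(states=['total'], outputs=['total'])
    def apply(self, x_t, total):
        return total + x_t
```

With an empty state list each step sees only its own inputs, so the loop reduces to a per-time-step map; with a declared state the value is carried between steps, which is why only that case needs initial states.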

Beronx86, Jun 12 '16 13:06

You can reproduce the error with the following code. The error occurs when the class does not contain a recurrent method named apply:

```python
import numpy
import theano
from numpy.testing import assert_allclose
from theano import tensor

from blocks.bricks.recurrent import BaseRecurrent, recurrent
# from recurrent import recurrent


class RecurrentWrapperNoStatesClass(BaseRecurrent):
    def __init__(self, dim, **kwargs):
        super(RecurrentWrapperNoStatesClass, self).__init__(**kwargs)
        self.dim = dim

    def get_dim(self, name):
        if name in ['inputs', 'outputs', 'outputs_2']:
            return self.dim
        if name == 'mask':
            return 0
        return super(RecurrentWrapperNoStatesClass, self).get_dim(name)

    # Stateless recurrent application; note that it is named apply2, not apply.
    @recurrent(sequences=['inputs', 'mask'], states=[],
               outputs=['outputs', 'outputs_2'], contexts=[])
    def apply2(self, inputs, mask=None):
        outputs = inputs * 10
        outputs_2 = tensor.sqr(inputs)
        if mask is not None:
            # Per step, mask has shape (batch,); broadcast it over features.
            outputs = outputs * mask[:, None]
            outputs_2 = outputs_2 * mask[:, None]
        return outputs, outputs_2


if __name__ == '__main__':
    recurrent_examples = RecurrentWrapperNoStatesClass(
        dim=11, name='test_example')

    X = tensor.tensor3('X')
    # Without the fix, this call fails when the wrapper calls initial_states.
    out, out_2 = recurrent_examples.apply2(inputs=X, mask=None)

    # Input layout is (time, batch, features), with features == dim.
    x_val = numpy.random.uniform(size=(5, 1, 11))
    x_val = numpy.asarray(x_val, dtype=theano.config.floatX)

    out_eval = out.eval({X: x_val})
    out_2_eval = out_2.eval({X: x_val})

    assert_allclose(x_val * 10, out_eval)
    assert_allclose(numpy.square(x_val), out_2_eval)
```

Beronx86, Jun 12 '16 14:06

I don't understand: now you have removed your fix, and having no states is again unsupported. Why not implement it as you did in the first place, but with the gentler changes to the code that I suggested?

rizar, Jun 21 '16 04:06