Conv2D initialize error inside tf.function
Hi, I am trying to write a GAN with a convnet component using Sonnet.
This is my base Conv2D class (it has some extra functionality, but I'm trying to give a minimal example):
import sonnet as snt

class Conv2D(snt.Conv2D):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, tensor):
        # Delegates to snt.Conv2D, which creates its variables on first call.
        output = super().__call__(tensor)
        return output
This is the simplified step function:
import tensorflow as tf

@tf.function
def step(batch_inputs, batch_targets, generator_obj, gen_optimizer):
    return gen_step(batch_inputs, batch_targets, generator_obj, gen_optimizer)
And this is the simplified gen_step function:
def gen_step(batch_inputs, batch_targets, generator_obj, gen_optimizer):
    with tf.GradientTape() as tape:
        batch_predictions = generator_obj(batch_inputs)
    return batch_predictions
When I run this I get an error from Conv2D. Here is the trace:
    /lustre/home/mo-txirouch/google/layers.py:73 __call__ *
        output = super().__call__(tensor)
    /home/mo-txirouch/.conda/envs/google/lib/python3.7/site-packages/sonnet/src/utils.py:27 _decorate_unbound_method *
        return decorator_fn(bound_method, self, args, kwargs)
    /home/mo-txirouch/.conda/envs/google/lib/python3.7/site-packages/sonnet/src/base.py:272 wrap_with_name_scope *
        return method(*args, **kwargs)
    /home/mo-txirouch/.conda/envs/google/lib/python3.7/site-packages/sonnet/src/conv.py:117 __call__ *
        self._initialize(inputs)
    /home/mo-txirouch/.conda/envs/google/lib/python3.7/site-packages/sonnet/src/utils.py:27 _decorate_unbound_method *
        return decorator_fn(bound_method, self, args, kwargs)
    /home/mo-txirouch/.conda/envs/google/lib/python3.7/site-packages/sonnet/src/once.py:93 wrapper *
        _check_no_output(wrapped(*args, **kwargs))
    /home/mo-txirouch/.conda/envs/google/lib/python3.7/site-packages/sonnet/src/utils.py:27 _decorate_unbound_method *
        return decorator_fn(bound_method, self, args, kwargs)
    /home/mo-txirouch/.conda/envs/google/lib/python3.7/site-packages/sonnet/src/base.py:272 wrap_with_name_scope *
        return method(*args, **kwargs)
    /home/mo-txirouch/.conda/envs/google/lib/python3.7/site-packages/sonnet/src/conv.py:143 _initialize *
        self.w = self._make_w()
    /home/mo-txirouch/.conda/envs/google/lib/python3.7/site-packages/sonnet/src/utils.py:27 _decorate_unbound_method *
        return decorator_fn(bound_method, self, args, kwargs)
    /home/mo-txirouch/.conda/envs/google/lib/python3.7/site-packages/sonnet/src/base.py:272 wrap_with_name_scope *
        return method(*args, **kwargs)
    /home/mo-txirouch/.conda/envs/google/lib/python3.7/site-packages/sonnet/src/conv.py:167 _make_w *
        return tf.Variable(self.w_init(weight_shape, self._dtype), name="w")
    /home/mo-txirouch/.conda/envs/google/lib/python3.7/site-packages/tensorflow/python/ops/variables.py:262 __call__ **
        return cls._variable_v2_call(*args, **kwargs)
    /home/mo-txirouch/.conda/envs/google/lib/python3.7/site-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
        shape=shape)
    /home/mo-txirouch/.conda/envs/google/lib/python3.7/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /home/mo-txirouch/.conda/envs/google/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py:769 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.
From print statements I gathered that the @tf.function-decorated step is traced again, presumably because of AutoGraph, which in turn causes the generator to rerun its variable creation and leads to this error.
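To illustrate the mechanism, here is a minimal sketch that should raise the same kind of ValueError without Sonnet. `LazyScale` is a hypothetical stand-in for a layer that creates its weight on first use; it is not the Sonnet code. The sketch assumes the retrace is triggered by passing a different Python object into the tf.function, which may or may not be what happens in the setup above.

```python
import tensorflow as tf


class LazyScale(tf.Module):
    """Hypothetical layer that creates its weight lazily, on first call."""

    def __init__(self):
        super().__init__()
        self.w = None

    def __call__(self, x):
        if self.w is None:
            # The variable is created in whatever context calls the layer.
            self.w = tf.Variable(tf.ones([x.shape[-1], 1]), name="w")
        return tf.matmul(x, self.w)


@tf.function
def step(x, layer):
    return layer(x)


x = tf.ones([2, 3])

# First call: tf.function permits variable creation during the first trace.
first = step(x, LazyScale())

# A different Python object as argument forces a retrace, and the fresh
# module then tries to create its variable on a non-first call.
error_raised = False
try:
    step(x, LazyScale())
except ValueError as err:
    error_raised = True
    print(err)

# If the variables are built eagerly beforehand, the retrace creates nothing.
prebuilt_layer = LazyScale()
prebuilt_layer(x)
prebuilt = step(x, prebuilt_layer)
```

If the real cause is similar, calling the generator once on a sample batch outside the tf.function, so that its variables already exist before `step` is traced, might avoid the error.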
I don't know whether this is an issue with the convolution in the Sonnet library or with my implementation/workflow, but after hacking at it for several days, any advice would be appreciated. Thank you!