[Help Request, CHAP 14 Related] Implementing unet with TensorFlow subclassing API does not work
I finished chapter 14 of the book (Computer Vision). At the end of each chapter, I implement a state-of-the-art architecture to test my knowledge. This time, however, my implementation did not work. I posted on Stack Overflow and the TensorFlow community forum, but I have no answers yet. I know this issue is not directly related to the book, but I am posting it here in case someone studying chapter 14 is interested in taking up the challenge and seeing what I am doing wrong.
@ageron I know you are very busy, but I would really appreciate it if you could take a look at it.
Problem
I am trying to implement U-Net with the TensorFlow subclassing API, but it does not work properly and I get the following error:
OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.
Furthermore, I am uncertain if I have correctly implemented the logic inside the call() function. Any help to correct my mistakes would be much appreciated.
Below are the full implementation and the error traceback:
Code Implementation:
from functools import partial

keras.backend.clear_session()
tf.random.set_seed(42)
np.random.seed(42)

conv2d = partial(keras.layers.Conv2D, kernel_size = 3,
                 padding = 'SAME',
                 kernel_initializer = 'he_normal',
                 use_bias = False)

conv2dtranspose = partial(keras.layers.Conv2DTranspose,
                          kernel_size = 2, strides = 2,
                          padding = 'SAME')

class encoder(keras.layers.Layer):
    def __init__(self, filters, **kwargs):
        super(encoder, self).__init__(**kwargs)
        self.convs = [
            conv2d(filters),
            keras.layers.BatchNormalization(),
            keras.layers.Activation('relu'),
            conv2d(filters),
            keras.layers.BatchNormalization(),
            keras.layers.Activation('relu')
        ]

    def call(self, inputs):
        Z = inputs
        for layer in self.convs:
            Z = layer(Z)
        return Z

class UNet(keras.models.Model):
    def __init__(self, filters, inputs_shape = [128, 128, 1], **kwargs):
        super(UNet, self).__init__(**kwargs)
        self.filters = filters
        self.inputs = keras.layers.Input(shape = inputs_shape)
        self.maxpool2d = keras.layers.MaxPool2D(pool_size = (2, 2), strides = 2)
        self.conv2dtranspose = conv2dtranspose
        self.concat = keras.layers.Concatenate()

    def call(self, inputs):
        skips = {}

        Z, inpt = inputs

        #implementing encoder path
        for fId in range(len(self.filters)):
            Z = encoder(filters = self.filters[fId])(Z)
            if fId < len(self.filters) - 1:
                skips[fId] = Z
                Z = self.maxpool2d(Z)

        #implementing decoder path
        for fId in reversed(range(len(self.filters) - 1)):
            Z = self.conv2dtranspose(self.filters[fId])(Z)
            Z = self.concat([Z, skips[fId]])
            Z = encoder(self.filters[::-1][fId])(Z)

        output = keras.layers.Conv2D(1, kernel_size = 1, activation = 'sigmoid')(Z)
        return keras.Model(inputs = [inpt], outputs = [output])

filters = [64, 128, 256, 512]

inpt = keras.layers.Input(shape = [128, 128, 1])
model = UNet(filters = filters)(inpt)

#Generating some test data
x = tf.random.normal(shape = (10, 128, 128, 1))
y = tf.random.normal(shape = (10, 128, 128, 1))

model.compile(loss = 'binary_crossentropy', optimizer = keras.optimizers.SGD(), metrics = ['accuracy'])
model.fit(x, y, epochs = 3)
Error traceback:
WARNING:tensorflow:AutoGraph could not transform <bound method UNet.call of <__main__.UNet object at 0x2930b3d30>> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, args, kwargs, caller_fn_scope, options)
446 program_ctx = converter.ProgramContext(options=options)
--> 447 converted_f = _convert_actual(target_entity, program_ctx)
448 if logging.has_verbosity(2):
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in _convert_actual(entity, program_ctx)
283
--> 284 transformed, module, source_map = _TRANSPILER.transform(entity, program_ctx)
285
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/transpiler.py in transform(self, obj, user_context)
285 if inspect.isfunction(obj) or inspect.ismethod(obj):
--> 286 return self.transform_function(obj, user_context)
287
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/transpiler.py in transform_function(self, fn, user_context)
469 # TODO(mdan): Confusing overloading pattern. Fix.
--> 470 nodes, ctx = super(PyToPy, self).transform_function(fn, user_context)
471
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/transpiler.py in transform_function(self, fn, user_context)
362 node = self._erase_arg_defaults(node)
--> 363 result = self.transform_ast(node, context)
364
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in transform_ast(self, node, ctx)
251 unsupported_features_checker.verify(node)
--> 252 node = self.initial_analysis(node, ctx)
253
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in initial_analysis(self, node, ctx)
238 graphs = cfg.build(node)
--> 239 node = qual_names.resolve(node)
240 node = activity.resolve(node, ctx, None)
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/qual_names.py in resolve(node)
251 def resolve(node):
--> 252 return QnResolver().visit(node)
253
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
370 visitor = getattr(self, method, self.generic_visit)
--> 371 return visitor(node)
372
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
446 if isinstance(value, AST):
--> 447 value = self.visit(value)
448 if value is None:
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
370 visitor = getattr(self, method, self.generic_visit)
--> 371 return visitor(node)
372
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
446 if isinstance(value, AST):
--> 447 value = self.visit(value)
448 if value is None:
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
370 visitor = getattr(self, method, self.generic_visit)
--> 371 return visitor(node)
372
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
455 elif isinstance(old_value, AST):
--> 456 new_node = self.visit(old_value)
457 if new_node is None:
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
370 visitor = getattr(self, method, self.generic_visit)
--> 371 return visitor(node)
372
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
455 elif isinstance(old_value, AST):
--> 456 new_node = self.visit(old_value)
457 if new_node is None:
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
370 visitor = getattr(self, method, self.generic_visit)
--> 371 return visitor(node)
372
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
446 if isinstance(value, AST):
--> 447 value = self.visit(value)
448 if value is None:
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
370 visitor = getattr(self, method, self.generic_visit)
--> 371 return visitor(node)
372
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
455 elif isinstance(old_value, AST):
--> 456 new_node = self.visit(old_value)
457 if new_node is None:
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
370 visitor = getattr(self, method, self.generic_visit)
--> 371 return visitor(node)
372
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/qual_names.py in visit_Subscript(self, node)
231 s = node.slice
--> 232 if not isinstance(s, gast.Index):
233 # TODO(mdan): Support range and multi-dimensional indices.
AttributeError: module 'gast' has no attribute 'Index'
During handling of the above exception, another exception occurred:
OperatorNotAllowedInGraphError Traceback (most recent call last)
<ipython-input-449-e6f92329b0db> in <module>
2
3 inpt = keras.layers.Input(shape = [128, 128, 1])
----> 4 model = UNet(filters = filters)(inpt)
5
6 #Generating some test data
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
944 # >> model = tf.keras.Model(inputs, outputs)
945 if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
--> 946 return self._functional_construction_call(inputs, args, kwargs,
947 input_list)
948
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _functional_construction_call(self, inputs, args, kwargs, input_list)
1083 layer=self, inputs=inputs, build_graph=True, training=training_value):
1084 # Check input assumptions set after layer building, e.g. input shape.
-> 1085 outputs = self._keras_tensor_symbolic_call(
1086 inputs, input_masks, args, kwargs)
1087
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs)
815 return nest.map_structure(keras_tensor.KerasTensor, output_signature)
816 else:
--> 817 return self._infer_output_signature(inputs, args, kwargs, input_masks)
818
819 def _infer_output_signature(self, inputs, args, kwargs, input_masks):
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _infer_output_signature(self, inputs, args, kwargs, input_masks)
856 # TODO(kaftan): do we maybe_build here, or have we already done it?
857 self._maybe_build(inputs)
--> 858 outputs = call_fn(inputs, *args, **kwargs)
859
860 self._handle_activity_regularization(inputs, outputs)
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
665 try:
666 with conversion_ctx:
--> 667 return converted_call(f, args, kwargs, options=options)
668 except Exception as e: # pylint:disable=broad-except
669 if hasattr(e, 'ag_error_metadata'):
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, args, kwargs, caller_fn_scope, options)
452 if is_autograph_strict_conversion_mode():
453 raise
--> 454 return _fall_back_unconverted(f, args, kwargs, options, e)
455
456 with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter():
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in _fall_back_unconverted(f, args, kwargs, options, exc)
499 logging.warn(warning_template, f, file_bug_message, exc)
500
--> 501 return _call_unconverted(f, args, kwargs, options)
502
503
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in _call_unconverted(f, args, kwargs, options, update_cache)
476
477 if kwargs is not None:
--> 478 return f(*args, **kwargs)
479 return f(*args)
480
<ipython-input-448-ce9f55fd84b1> in call(self, inputs)
49 skips = {}
50
---> 51 Z, inpt = inputs
52
53 #implementing encoder path
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in __iter__(self)
503 def __iter__(self):
504 if not context.executing_eagerly():
--> 505 self._disallow_iteration()
506
507 shape = self._shape_tuple()
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in _disallow_iteration(self)
499 else:
500 # Default: V1-style Graph execution.
--> 501 self._disallow_in_graph_mode("iterating over `tf.Tensor`")
502
503 def __iter__(self):
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in _disallow_in_graph_mode(self, task)
477
478 def _disallow_in_graph_mode(self, task):
--> 479 raise errors.OperatorNotAllowedInGraphError(
480 "{} is not allowed in Graph execution. Use Eager execution or decorate"
481 " this function with @tf.function.".format(task))
OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.
Hi @ethenkaufmann ,
Thanks for your feedback. There are a few issues in the code:
- When defining a model or a layer using subclassing, you should not create `Input` (so no need to pass `input_shape` to the constructor either): just use the `inputs` argument of the `call()` method. Also, the call method should just return `output`, it should not create a new `Model`.
- Every layer containing parameters should be created in the `__init__()` method (or in the `build()` method). Never create a layer with parameters inside the `call()` method. Currently, your code does that in several places, for example `encoder(filters = self.filters[fId])(Z)` creates a new `encoder` (btw, by convention classes should start with a capital letter). Similarly, `self.conv2dtranspose(self.filters[fId])(Z)` creates a new layer with parameters as well. All these layers should be created ahead of time, in `__init__()` (or `build()`).
- `Z, inpt = inputs` won't work since there's just one input. Perhaps you meant `Z = inpt = inputs`? In any case you don't need `inpt` anymore since you must not create a `Model` instance at the end of the `call()` method, just return `output`. So `Z = inputs` is sufficient.
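For illustration, here is a rough sketch of how the points above might be applied to the code in the question. It is only one possible arrangement, not a reference implementation: the names `ConvBlock`, `down_blocks`, `bottleneck`, `upsamplers`, `up_blocks` and `head` are made up for this example, and the decoder blocks here use the same filter counts as the matching encoder levels, which is an assumption on my part.

```python
from functools import partial

import tensorflow as tf
from tensorflow import keras

# Same convolution settings as in the question.
Conv3x3 = partial(keras.layers.Conv2D, kernel_size=3, padding="same",
                  kernel_initializer="he_normal", use_bias=False)


class ConvBlock(keras.layers.Layer):
    """Two Conv-BN-ReLU stages (the `encoder` class from the question, renamed)."""

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        # All sub-layers are created once, here, not inside call().
        self.block_layers = []
        for _ in range(2):
            self.block_layers += [Conv3x3(filters),
                                  keras.layers.BatchNormalization(),
                                  keras.layers.Activation("relu")]

    def call(self, inputs):
        Z = inputs
        for layer in self.block_layers:
            Z = layer(Z)
        return Z


class UNet(keras.Model):
    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        decoder_filters = list(reversed(filters[:-1]))
        # Encoder path: one block per level, plus a bottleneck block.
        self.down_blocks = [ConvBlock(f) for f in filters[:-1]]
        self.bottleneck = ConvBlock(filters[-1])
        # Decoder path: one upsampling layer and one block per level.
        self.upsamplers = [
            keras.layers.Conv2DTranspose(f, kernel_size=2, strides=2,
                                         padding="same")
            for f in decoder_filters]
        self.up_blocks = [ConvBlock(f) for f in decoder_filters]
        # These two have no parameters, so a single instance can be reused.
        self.pool = keras.layers.MaxPool2D(pool_size=2)
        self.concat = keras.layers.Concatenate()
        self.head = keras.layers.Conv2D(1, kernel_size=1, activation="sigmoid")

    def call(self, inputs):
        Z = inputs  # a single input tensor: no Input layer, no unpacking
        skips = []
        for block in self.down_blocks:
            Z = block(Z)
            skips.append(Z)
            Z = self.pool(Z)
        Z = self.bottleneck(Z)
        for upsample, block, skip in zip(self.upsamplers, self.up_blocks,
                                         reversed(skips)):
            Z = upsample(Z)
            Z = self.concat([Z, skip])
            Z = block(Z)
        return self.head(Z)  # return the output tensor, not a new Model


model = UNet(filters=[64, 128, 256, 512])
x = tf.random.normal(shape=(2, 128, 128, 1))
y = model(x)           # the model is called directly on data
print(tuple(y.shape))  # (2, 128, 128, 1)
```

Note that the model instance itself is what gets compiled and fitted (`model = UNet(filters)`, then `model.compile(...)` and `model.fit(...)`), rather than the result of calling it on an `Input`. Also, with `binary_crossentropy` the targets should lie in [0, 1], which is not the case for targets drawn from `tf.random.normal`.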
Could you try fixing these issues and tell me if it helps? If not, please post the updated code.
Cheers
For this model, the Functional API is probably much easier. For example, UNet is implemented here: https://keras.io/examples/vision/oxford_pets_image_segmentation/
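To give an idea of what that could look like, here is a minimal Functional API sketch using the same kind of blocks as in the question. This is not the code from the linked example (which uses a different architecture); `conv_block` and `build_unet` are hypothetical helper names chosen for this sketch.

```python
from tensorflow import keras


def conv_block(Z, filters):
    """Apply two Conv-BN-ReLU stages to the symbolic tensor Z."""
    for _ in range(2):
        Z = keras.layers.Conv2D(filters, kernel_size=3, padding="same",
                                kernel_initializer="he_normal",
                                use_bias=False)(Z)
        Z = keras.layers.BatchNormalization()(Z)
        Z = keras.layers.Activation("relu")(Z)
    return Z


def build_unet(input_shape=(128, 128, 1), filters=(64, 128, 256, 512)):
    # With the Functional API, creating Input and Model explicitly is correct.
    inputs = keras.layers.Input(shape=input_shape)
    Z = inputs
    skips = []
    # Encoder path: keep each level's output for the skip connections.
    for f in filters[:-1]:
        Z = conv_block(Z, f)
        skips.append(Z)
        Z = keras.layers.MaxPool2D(pool_size=2)(Z)
    Z = conv_block(Z, filters[-1])
    # Decoder path: upsample, concatenate the matching skip, convolve.
    for f, skip in zip(reversed(filters[:-1]), reversed(skips)):
        Z = keras.layers.Conv2DTranspose(f, kernel_size=2, strides=2,
                                         padding="same")(Z)
        Z = keras.layers.Concatenate()([Z, skip])
        Z = conv_block(Z, f)
    outputs = keras.layers.Conv2D(1, kernel_size=1, activation="sigmoid")(Z)
    return keras.Model(inputs=inputs, outputs=outputs)


model = build_unet()
print(model.output_shape)  # (None, 128, 128, 1)
```

Here each layer is created exactly once while the graph is being described, so the "no new layers inside `call()`" rule does not come up at all.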
Alternatively, if the goal is to learn how to use the Subclassing API, I suggest you start with a very basic model first. For example, a model with just one Dense layer. Then add a couple more layers. Then add a skip connection. Etc. Work your way slowly to a complex model.
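As a rough illustration of such an intermediate step, a small subclassed model with a couple of Dense layers and one skip connection might look like the sketch below. The class name `TinySkipModel` and the layer sizes are arbitrary choices for this example.

```python
import tensorflow as tf
from tensorflow import keras


class TinySkipModel(keras.Model):
    def __init__(self, units=8, **kwargs):
        super().__init__(**kwargs)
        # Layers with parameters are created here, once.
        self.hidden1 = keras.layers.Dense(units, activation="relu")
        self.hidden2 = keras.layers.Dense(units, activation="relu")
        self.concat = keras.layers.Concatenate()
        self.out = keras.layers.Dense(1)

    def call(self, inputs):
        Z = self.hidden1(inputs)
        Z = self.hidden2(Z)
        Z = self.concat([inputs, Z])  # skip connection around the hidden layers
        return self.out(Z)            # return a tensor, not a Model


model = TinySkipModel()
x = tf.random.normal(shape=(4, 3))
print(tuple(model(x).shape))  # (4, 1)
```

Once something like this trains as expected, the same pattern (create in `__init__()`, use in `call()`) scales up to convolutional blocks and multiple skip connections.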
You'll probably want to define a build() method. See the book and/or notebook for an example.
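In case it helps, here is a minimal sketch of a custom layer whose weights are created in `build()`, in the same spirit as the book's custom layer example but not copied from it. The class name `MyDense` and the initializers are assumptions made for this illustration.

```python
import tensorflow as tf
from tensorflow import keras


class MyDense(keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units  # no weights yet: the input size is still unknown

    def build(self, input_shape):
        # Called automatically the first time the layer is used, once the
        # input shape is known.
        self.kernel = self.add_weight(name="kernel",
                                      shape=(input_shape[-1], self.units),
                                      initializer="glorot_normal")
        self.bias = self.add_weight(name="bias", shape=(self.units,),
                                    initializer="zeros")

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel) + self.bias


layer = MyDense(units=3)
print(len(layer.weights))         # 0: nothing has been created yet
outputs = layer(tf.ones((2, 5)))  # the first call triggers build()
print(tuple(layer.kernel.shape))  # (5, 3)
```

The point of `build()` is that anything depending on the input shape (for example the number of input channels) can be deferred until the layer first sees data, while everything else can still be created in `__init__()`.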