TensorFlow-Advanced-Segmentation-Models
Support for 'mixed_float16'
I saw your Medium article from October 7 and wanted to try it out on my RTX 3090 (training took 143.27 s and inference 196.57 s for semantic segmentation). Then I wanted to test with 16-bit floats by adding:
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')
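For reference, here is a minimal sketch of what this global policy does, assuming TensorFlow 2.4 or later. Layers created after `set_global_policy('mixed_float16')` compute in float16 but keep float32 weights, and the Keras mixed precision guide recommends keeping the model's final softmax in float32. The `Conv2D` and softmax layers are illustrative stand-ins, not part of TensorFlow-Advanced-Segmentation-Models.

```python
import tensorflow as tf
from tensorflow.keras import mixed_precision

mixed_precision.set_global_policy('mixed_float16')
policy = mixed_precision.global_policy()
print(policy.compute_dtype, policy.variable_dtype)  # float16 float32

# Layers built under this policy autocast float32 inputs to float16.
conv = tf.keras.layers.Conv2D(8, 3, padding='same')
features = conv(tf.zeros((1, 16, 16, 3)))

# Keep the output activation in float32 for numerical stability.
probs = tf.keras.layers.Activation('softmax', dtype='float32')(features)

# Restore the default policy so later code in this session is unaffected.
mixed_precision.set_global_policy('float32')
```

Built-in Keras layers handle the casting themselves. Custom `call()` code that mixes raw `tf` ops must keep its tensors in a consistent dtype on its own.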
However, model.fit then crashes because the custom layers in _custom_layers_and_blocks.py do not support float16:
File ~/anaconda3/lib/python3.9/site-packages/tensorflow_advanced_segmentation_models/models/_custom_layers_and_blocks.py:252, in AtrousSpatialPyramidPoolingV3.call(self, input_tensor, training)
249 z = self.atrous_sepconv_bn_relu_3(input_tensor, training=training)
251 # concatenation
--> 252 net = tf.concat([glob_avg_pool, w, x, y, z], axis=-1)
253 net = self.conv_reduction_1(net, training=training)
255 return net
InvalidArgumentError: Exception encountered when calling layer 'atrous_spatial_pyramid_pooling_v3_4' (type AtrousSpatialPyramidPoolingV3).
cannot compute ConcatV2 as input #1(zero-based) was expected to be a float tensor but is a half tensor [Op:ConcatV2] name: concat
Call arguments received by layer 'atrous_spatial_pyramid_pooling_v3_4' (type AtrousSpatialPyramidPoolingV3):
• input_tensor=tf.Tensor(shape=(16, 40, 40, 288), dtype=float16)
• training=True
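The error says input #0 of the concat (glob_avg_pool) is float32 while input #1 (w) is float16. One plausible source, not confirmed against the library code, is the image-pooling branch: `tf.image.resize` returns float32 for every method except nearest neighbor, so an upsampled pooled tensor comes back as float32 even when its input is float16. The sketch below reproduces the mismatch with stand-in tensors and shows a possible interim workaround: cast every branch to one dtype before concatenating. `concat_as` is a hypothetical helper. Inside a Keras layer you would pass `self.compute_dtype` rather than a fixed dtype.

```python
import tensorflow as tf

# Hypothetical stand-ins for the five ASPP branches in the traceback:
# a float32 pooled branch followed by four float16 conv branches.
glob_avg_pool = tf.zeros((2, 40, 40, 8), dtype=tf.float32)
branches = [tf.zeros((2, 40, 40, 8), dtype=tf.float16) for _ in range(4)]

# Concatenating mixed dtypes fails (InvalidArgumentError when run eagerly,
# TypeError when traced inside tf.function).
try:
    tf.concat([glob_avg_pool] + branches, axis=-1)
    mismatch_raised = False
except (tf.errors.InvalidArgumentError, TypeError, ValueError):
    mismatch_raised = True


def concat_as(tensors, dtype, axis=-1):
    """Hypothetical helper: cast every tensor to `dtype`, then concatenate."""
    return tf.concat([tf.cast(t, dtype) for t in tensors], axis=axis)


# Inside a Keras layer's call(), use self.compute_dtype instead of tf.float16.
net = concat_as([glob_avg_pool] + branches, tf.float16)
print(net.dtype, net.shape.as_list())
```

Casting to the layer's compute dtype keeps the block correct under both the default float32 policy and mixed_float16.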
Any chance of an update to support float16?