TensorScatterUpdate unsupported operation

Open FabioRomagnolo opened this issue 3 years ago • 3 comments

Hi! I'm developing a neural network in Keras that uses the tensor_scatter_nd_update function. When I convert the graph to TensorFlow.js and run inference, I get an error saying that the TensorScatterUpdate operation is unsupported. I tried reimplementing it myself in Python and JavaScript, but inference becomes too slow, WebGL crashes, or TFJS runs into other compatibility issues.

Is there any hope of having this operation implemented, or at least of having an optimized alternative that does the same thing?

These are two of the equivalent implementations I've written, but they're not optimized and they have many compatibility problems with TFJS, so they're not useful at all.

    def custom_tensor_scatter_nd_update(self, tensor, indices, updates):
        """
        Alternative to tf.tensor_scatter_nd_update built on the more standard tf.scatter_nd function, which is
        generally better supported when converting the model to TensorFlow.js.
        Docs:
        -   tf.tensor_scatter_nd_update: https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update
        -   tf.scatter_nd: https://www.tensorflow.org/api_docs/python/tf/scatter_nd
        :param tensor: Existing tensor to scatter into.
        :param indices: Indices at which to scatter.
        :param updates: Values to scatter at the given indices.
        :return: Copy of the tensor with the updates applied.
        """
        '''
        # Equivalent implementation using gather and scatter.
        # WARNING: WebGL may crash
        shape = tf.shape(tensor, out_type=indices.dtype)
        # a = tensor * tf.scatter_nd(indices, zeros, shape) # Zero problem here!
        # Zeroing the indices we want to update
        obsolete_values = tf.gather_nd(tensor, indices)
        a = tensor - tf.scatter_nd(indices, obsolete_values, shape)
        # Taking the updates
        b = tf.scatter_nd(indices, updates, shape)
        return a + b
        '''
        # Equivalent implementation using where.
        # WARNING: TFJS does not support tf.where with rank 6
        shape = tf.shape(tensor, out_type=indices.dtype)
        ones = tf.ones_like(updates)
        patch = tf.scatter_nd(indices, updates, shape)
        placeholders = tf.scatter_nd(indices, ones, shape)
        return tf.where(placeholders > 0, patch, tensor)
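
For readers who want to see the logic without TensorFlow, here is a minimal pure-Python sketch of the two strategies above, restricted to 1-D lists. All function names are hypothetical and exist only for this illustration. The helper `scatter_nd_1d` imitates the documented behaviour of tf.scatter_nd of summing updates that share an index, and the reference function assumes that the last update wins on duplicates, which is a choice made for this sketch rather than a statement about TensorFlow.

```python
def scatter_nd_1d(indices, updates, size):
    """Stand-in for tf.scatter_nd with a 1-D output: start from zeros
    and sum every update into its slot (duplicates accumulate)."""
    out = [0] * size
    for i, u in zip(indices, updates):
        out[i] += u
    return out


def update_via_mask(tensor, indices, updates):
    """Mask-and-select strategy, mirroring the tf.where variant."""
    size = len(tensor)
    patch = scatter_nd_1d(indices, updates, size)
    mask = scatter_nd_1d(indices, [1] * len(updates), size)
    # Take the patched value wherever the mask was hit, else keep the original.
    return [p if m > 0 else t for t, p, m in zip(tensor, patch, mask)]


def update_via_subtract(tensor, indices, updates):
    """Subtract-old, add-new strategy, mirroring the gather/scatter variant."""
    size = len(tensor)
    old = [tensor[i] for i in indices]            # plays the role of gather_nd
    removed = scatter_nd_1d(indices, old, size)
    added = scatter_nd_1d(indices, updates, size)
    return [t - r + a for t, r, a in zip(tensor, removed, added)]


def update_reference(tensor, indices, updates):
    """Direct assignment on a copy. Assumption of this sketch: on
    duplicate indices the last update wins."""
    out = list(tensor)
    for i, u in zip(indices, updates):
        out[i] = u
    return out


tensor = [10, 20, 30, 40]

# With unique indices, all three approaches agree.
unique = update_reference(tensor, [1, 3], [7, 9])
print(unique)
print(update_via_mask(tensor, [1, 3], [7, 9]) == unique)
print(update_via_subtract(tensor, [1, 3], [7, 9]) == unique)

# With duplicate indices, the scatter-based emulations sum the updates
# and no longer match plain assignment.
print(update_reference(tensor, [1, 1], [7, 9]))
print(update_via_mask(tensor, [1, 1], [7, 9]))
print(update_via_subtract(tensor, [1, 1], [7, 9]))
```

The takeaway is that both emulations are only equivalent to a true update when the indices are unique. The subtract-and-add form also relies on `t - t` cancelling exactly, which would not hold for NaN or infinite values.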

Any help would be greatly appreciated, thank you!

FabioRomagnolo, Aug 02 '22 12:08

Similar requests: https://github.com/tensorflow/tfjs/issues/6395 and https://github.com/tensorflow/tfjs/issues/4222

rthadur, Aug 02 '22 18:08

@FabioRomagnolo thanks for the feature request. Based on the definition of ScatterNdUpdate, the input reference tensor should be a variable tensor, which is not supported in our GraphModel. Can you share your converted model so we can verify? Thanks.

pyu10055, Aug 05 '22 21:08

Thank you for the reply! This is the original model without custom operations. Inside you should find the TensorScatterUpdate function: model.zip

FabioRomagnolo, Aug 08 '22 08:08