tf.unstack / tf.split on input loses dynamic batch size

Open cchan-lm opened this issue 3 years ago • 3 comments

Describe the bug When tf.unstack or tf.split is used on an input, dynamic batch dimension is lost.

This was found while trying to export NodLabs' DLRM. An example script is provided that builds the model's architecture and saves it. Inputs are [dense_features, sparse_features], and they are split up in the model's call method. The resulting ONNX model shows that dense_features kept its dynamic batch dimension, but sparse_features has been unrolled into separate inputs and its dynamic batch dimension is lost.

Urgency Supporting external customers with guidance on how to export TF to ONNX, so we would like to know very soon. Understandably, it's the holiday season, so if there's a resolution or workaround by mid-January, that would be great :) Thank you for any assistance!

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): MacOS Big Sur 11.5.1 and Ubuntu 16.04
  • Tensorflow Version: 2.6.0 and 2.7.0 (cpu)
  • Python version: 3.9.0
  • tf2onnx version: nightly (fd93ca8)

To Reproduce

  1. dlrm_example.py:
# Source: https://github.com/NodLabs/tensorflow-dlrm/blob/master/noddlrm/recommenders/dlrm.py

import sys
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Layer, Embedding


class LatentFactor(Embedding):
    
    def __init__(self, num_instances, dim, zero_init=False, name=None):
        
        if zero_init:
            initializer = 'zeros'
        else:
            initializer = 'uniform'
        super(LatentFactor, self).__init__(input_dim=num_instances, 
                                           output_dim=dim, 
                                           embeddings_initializer=initializer,
                                           name=name)
    
    def censor(self, censor_id):
        
        unique_censor_id, _ = tf.unique(censor_id)
        embedding_gather = tf.gather(self.variables[0], indices=unique_censor_id)
        norm = tf.norm(embedding_gather, axis=1, keepdims=True)
        return self.variables[0].scatter_nd_update(indices=tf.expand_dims(unique_censor_id, 1), 
                                                   updates=embedding_gather / tf.math.maximum(norm, 0.1))


def MLP(units_list, use_bias=True, activation='relu', out_activation=None):
    
    mlp = Sequential()
    
    for units in units_list[:-1]:
        mlp.add(Dense(units, 
                        activation=activation, 
                        use_bias=use_bias))
    
    mlp.add(Dense(units_list[-1], 
                activation=out_activation, 
                use_bias=use_bias))
    
    return mlp


class SecondOrderFeatureInteraction(Layer):
    
    def __init__(self, self_interaction=False):
        
        self._self_interaction = self_interaction
        
        super(SecondOrderFeatureInteraction, self).__init__()
    
    def call(self, inputs):
        
        '''
        inputs: list of features with shape [batch_size, feature_dim]
        '''
        
        batch_size = tf.shape(inputs[0])[0]
        
        concat_features = tf.stack(inputs, axis=1)
        dot_products = tf.linalg.LinearOperatorLowerTriangular(tf.matmul(concat_features, concat_features, transpose_b=True)).to_dense()

        ones = tf.ones_like(dot_products)
        mask = tf.linalg.band_part(ones, 0, -1)
        
        if not self._self_interaction:
            mask = mask - tf.linalg.band_part(ones, 0, 0)
            out_dim = int(len(inputs) * (len(inputs)-1) / 2)
        else:
            out_dim = int(len(inputs) * (len(inputs)+1) / 2)
        
        flat_interactions = tf.reshape(tf.boolean_mask(dot_products, mask), (batch_size, out_dim))
            
        return flat_interactions


# Method 4
# class SparseProcessorLayer(tf.keras.layers.Layer):
#     def __init__(self, latent_factors):
#         super().__init__()
#         self._latent_factors = latent_factors
#     def call(self, sparse_features):
#         sparse_emb_vecs = list(map(lambda pair: pair[1](pair[0]),
#                                       zip(tf.unstack(sparse_features, axis=1), 
#                                           self._latent_factors)))
#         return sparse_emb_vecs


class DLRM(Model):
    
    def __init__(
        self, 
        m_spa,
        ln_emb,
        ln_bot,
        ln_top,
        arch_interaction_op='dot',
        arch_interaction_itself=False,
        sigmoid_bot=False,
        sigmoid_top=True,
        loss_func='mse',
        loss_threshold=0.0):
        
        '''
        m_spa: the dimensionality of sparse feature embeddings
        ln_emb: the size of sparse feature embeddings (num_instances)
        ln_bot: the size of the bottom MLP
        ln_top: the size of the top MLP
        '''
        
        super(DLRM, self).__init__()
        
        self._loss_threshold = loss_threshold
        self._loss_func = loss_func
        self._latent_factors = [LatentFactor(num_instances=num, 
                                             dim=m_spa) for num in ln_emb]

        # For Method 4
        # sparse_input = tf.keras.Input(len(ln_emb))
        # sparse_output = SparseProcessorLayer(self._latent_factors)(sparse_input)
        # self.sparse_processor = tf.keras.Model(sparse_input, sparse_output)

        self._mlp_bot = MLP(units_list=ln_bot, 
                            out_activation='sigmoid' if sigmoid_bot else 'relu')
        self._mlp_top = MLP(units_list=ln_top, 
                            out_activation='sigmoid' if sigmoid_top else 'relu')
        
        self._dot_interaction = None
        if arch_interaction_op == 'dot':
            self._dot_interaction = SecondOrderFeatureInteraction(
                                        self_interaction=arch_interaction_itself
                                    )
        
        elif arch_interaction_op != 'cat':
            # Use the constructor argument; self._arch_interaction_op is never set.
            sys.exit(
                "ERROR: arch_interaction_op="
                + arch_interaction_op
                + " is not supported"
            )
        
        if loss_func == 'mse':
            self._loss = tf.keras.losses.MeanSquaredError()
        elif loss_func == 'bce':
            self._loss = tf.keras.losses.BinaryCrossentropy()
        else:
            sys.exit(
                "ERROR: loss_func="
                + loss_func
                + " is not supported"
            )
        
    def get_myloss(self, dense_features, sparse_features, label):
        
        '''
        dense_features shape: [batch_size, num of dense features]
        sparse_features shape: [batch_size, num_of_sparse_features]
        label shape: [batch_size]
        '''
        
        prediction = self.inference(dense_features, sparse_features)
        loss = self._loss(y_true=label, 
                          y_pred=prediction)
        return loss

    def call(self, inputs, training=None, mask=None):
        dense_features, sparse_features = inputs
        return self.inference(dense_features, sparse_features)

    def inference(self, dense_features, sparse_features):
    
        '''
        dense_features shape: [batch_size, num of dense features]
        sparse_features shape: [batch_size, num_of_sparse_features]
        '''
        self._set_inputs([dense_features, sparse_features])

        # Original method:
        sparse_emb_vecs = list(map(lambda pair: pair[1](pair[0]),
                                      zip(tf.unstack(sparse_features, axis=1), 
                                          self._latent_factors)))

        # Method 1 - don't use map + lambda
        # sparse_emb_vecs = [None]*len(self._latent_factors)
        # sparse_unstacked = tf.unstack(sparse_features, axis=1)
        # for i, latent_factor in enumerate(self._latent_factors):
        #     sparse_emb_vecs[i] = latent_factor(sparse_unstacked[i])

        # Method 2 - use tf.split
        # sparse_unstacked = tf.split(sparse_features, len(self._latent_factors), axis=1)
        # sparse_unstacked = list(map(lambda x: tf.reshape(x, [-1]), sparse_unstacked))
        # sparse_emb_vecs = list(map(lambda pair: pair[1](pair[0]),
        #                             zip(sparse_unstacked, self._latent_factors)))
    
        # Method 3 - use tf.split without map + lambda
        # sparse_emb_vecs = [None]*len(self._latent_factors)
        # sparse_unstacked = tf.split(sparse_features, len(self._latent_factors), axis=1)
        # for i, latent_factor in enumerate(self._latent_factors):
        #     sparse_emb_vecs[i] = latent_factor(tf.reshape(sparse_unstacked[i], [-1]))

        # Method 4 - use submodel
        # sparse_emb_vecs = self.sparse_processor(sparse_features)

        dense_emb_vec = self._mlp_bot(dense_features)
        
        if self._dot_interaction is not None:
            prediction = self._mlp_top(tf.concat([dense_emb_vec, 
                                              self._dot_interaction(sparse_emb_vecs + [dense_emb_vec])],
                                             axis=1))
        else:
            prediction = self._mlp_top(tf.concat(sparse_emb_vecs + [dense_emb_vec], 
                                             axis=1))
        
        if 0.0 < self._loss_threshold and self._loss_threshold < 1.0:
            prediction = tf.clip_by_value(prediction, self._loss_threshold, 1.0 - self._loss_threshold)
        
        return tf.reshape(prediction, [-1])


def test_dlrm():
    dim_embed = 4
    bottom_mlp_size = [8, 4]
    top_mlp_size = [128, 64, 1]

    # dense_features shape: [batch_size, num of dense features]
    # sparse_features shape: [batch_size, num_of_sparse_features]
    # Shapes and types below were gleaned from the processed Criteo dataset

    # See:
    # https://github.com/facebookresearch/dlrm/blob/main/data_utils.py
    # https://github.com/NodLabs/tensorflow-dlrm/blob/master/dataloader.py
    # https://github.com/NodLabs/tensorflow-dlrm/blob/master/dlrm_criteo_gpu.py

    x_int = [[0, 97, 0, 47, 34, 0, 0, 0, 0, 0, 7, 21785, 0],
             [5, 351, 6, 5, 0, 7, 0, 4, 5, 3, 0, 33, 5],
             [5, 339, 0, 0, 0, 0, 0, 144, 0, 0, 0, 44628, 0],
             [1, 550, 0, 0, 78, 0, 0, 40, 5,0, 12, 1228, 0]]
    x_cat = [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
             [1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
             [2., 2., 2., 2., 0., 1., 2., 2., 0., 2., 2., 2., 2., 1., 2., 1., 1., 2., 1., 2., 2., 2., 2., 2., 0., 0.],
             [3., 3., 3., 3., 0., 1., 3., 3., 0., 3., 3., 3., 1., 2., 3., 2., 2., 3., 1., 3., 3., 3., 3., 3., 0., 1.]]

    counts = [97, 99, 99, 87, 95, 2, 99, 88, 14, 95, 94, 98, 9, 57, 96, 15, 4, 40, 12, 97, 97, 97, 90, 98, 12, 17]
    dense_features = np.log(x_int).astype(np.float32) 
    sparse_features = x_cat

    dlrm_model = DLRM(
                    m_spa=dim_embed,
                    ln_emb=counts,
                    ln_bot=bottom_mlp_size,
                    ln_top=top_mlp_size
                    )

    # Model does not have Input layer, have to pass input in order to save model
    dlrm_model([dense_features, sparse_features])
    dlrm_model.save("dlrm")


if __name__ == "__main__":
    test_dlrm()
  2. Run python3 dlrm_example.py. This saves the TensorFlow model to the dlrm directory. The script only builds the architecture; the model is not trained.
  • In this file, I've commented out a few methods that I tried instead of the original:
    • Method 1: tf.unstack without map + lambda
    • Method 2: tf.split with map + lambda
    • Method 3: tf.split without map + lambda
    • Method 4: Use a submodel that uses a specified tf.keras.Input tensor with dynamic batch size
  3. Run python3 -m tf2onnx.convert --saved-model dlrm --output dlrm.onnx --opset 16 --verbose

Screenshots Model inputs are logged as below, showing that the 2nd input has been unrolled and has lost its dynamic batch dimension:

2021-12-16 17:46:44,643 - INFO - Model inputs: ['input_1', 'input_2_1_1', 'input_2_1_10', 'input_2_1_11', 'input_2_1_12', 'input_2_1_13', 'input_2_1_14', 'input_2_1_15', 'input_2_1_16', 'input_2_1_17', 'input_2_1_18', 'input_2_1_19', 'input_2_1_2', 'input_2_1_20', 'input_2_1_21', 'input_2_1_22', 'input_2_1_23', 'input_2_1_24', 'input_2_1_25', 'input_2_1_26', 'input_2_1_3', 'input_2_1_4', 'input_2_1_5', 'input_2_1_6', 'input_2_1_7', 'input_2_1_8', 'input_2_1_9', 'input_2_2_1', 'input_2_2_10', 'input_2_2_11', 'input_2_2_12', 'input_2_2_13', 'input_2_2_14', 'input_2_2_15', 'input_2_2_16', 'input_2_2_17', 'input_2_2_18', 'input_2_2_19', 'input_2_2_2', 'input_2_2_20', 'input_2_2_21', 'input_2_2_22', 'input_2_2_23', 'input_2_2_24', 'input_2_2_25', 'input_2_2_26', 'input_2_2_3', 'input_2_2_4', 'input_2_2_5', 'input_2_2_6', 'input_2_2_7', 'input_2_2_8', 'input_2_2_9', 'input_2_3_1', 'input_2_3_10', 'input_2_3_11', 'input_2_3_12', 'input_2_3_13', 'input_2_3_14', 'input_2_3_15', 'input_2_3_16', 'input_2_3_17', 'input_2_3_18', 'input_2_3_19', 'input_2_3_2', 'input_2_3_20', 'input_2_3_21', 'input_2_3_22', 'input_2_3_23', 'input_2_3_24', 'input_2_3_25', 'input_2_3_26', 'input_2_3_3', 'input_2_3_4', 'input_2_3_5', 'input_2_3_6', 'input_2_3_7', 'input_2_3_8', 'input_2_3_9', 'input_2_4_1', 'input_2_4_10', 'input_2_4_11', 'input_2_4_12', 'input_2_4_13', 'input_2_4_14', 'input_2_4_15', 'input_2_4_16', 'input_2_4_17', 'input_2_4_18', 'input_2_4_19', 'input_2_4_2', 'input_2_4_20', 'input_2_4_21', 'input_2_4_22', 'input_2_4_23', 'input_2_4_24', 'input_2_4_25', 'input_2_4_26', 'input_2_4_3', 'input_2_4_4', 'input_2_4_5', 'input_2_4_6', 'input_2_4_7', 'input_2_4_8', 'input_2_4_9']
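The logged names follow the pattern input_2_<r>_<c>. A minimal sketch, assuming that <r> indexes the 4 rows of x_cat and <c> the 26 sparse features in dlrm_example.py (this is an interpretation of the names, not something confirmed in this thread), rebuilds the same list of names in the same lexicographic order as the log:

```python
# Assumption: the two numeric suffixes are (row, column) positions in the
# 4 x 26 x_cat example input from dlrm_example.py.
num_rows = 4      # rows of x_cat, i.e. the example batch
num_sparse = 26   # len(counts), i.e. the number of sparse features

expected = sorted(
    ["input_1"]
    + [
        f"input_2_{r}_{c}"
        for r in range(1, num_rows + 1)
        for c in range(1, num_sparse + 1)
    ]
)

# One dense input plus one scalar input per element of the example batch.
print(len(expected))
print(expected[:4])
```

If this reading is right, every element of the example batch became its own graph input, which would be consistent with the second input having no dynamic batch dimension left.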

As viewed in Netron:

[Netron screenshots: dlrm_tf2onnx_1, dlrm_tf2onnx_2]

cchan-lm avatar Dec 16 '21 18:12 cchan-lm

Any news on this? I also have a split node in my ONNX graph, and it's breaking the network when I use a dynamic batch size

montmejat avatar Apr 26 '22 15:04 montmejat

I have not yet found a workaround due to other efforts but do have to revisit this... @fatcat-z do you know of anything?

cchan-lm avatar Apr 26 '22 18:04 cchan-lm

I have not yet found a workaround due to other efforts but do have to revisit this... @fatcat-z do you know of anything?

I just noticed that the saved model generated by calling dlrm_model.save() already has its inputs changed in the way you described, so I believe this is not done by tf2onnx. When tf2onnx loads the saved model, the unrolled inputs are already there.

fatcat-z avatar Aug 26 '22 09:08 fatcat-z
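
The check described in the last comment can be reproduced on any SavedModel by reading its serving signature before running tf2onnx. A minimal sketch, assuming TensorFlow 2.x; the Toy module and the input name sparse_features are hypothetical stand-ins, not the DLRM model from this issue. For the model in this thread, the same loading and printing steps would be pointed at the dlrm directory instead.

```python
import tempfile

import tensorflow as tf


class Toy(tf.Module):
    """Hypothetical stand-in that unstacks a single 2-D input tensor."""

    @tf.function(
        input_signature=[
            tf.TensorSpec([None, 26], tf.float32, name="sparse_features")
        ]
    )
    def __call__(self, sparse_features):
        # Split the [batch, 26] tensor into 26 tensors of shape [batch].
        columns = tf.unstack(sparse_features, axis=1)
        return tf.add_n(columns)


export_dir = tempfile.mkdtemp()
toy = Toy()
tf.saved_model.save(
    toy, export_dir, signatures=toy.__call__.get_concrete_function()
)

# Read back the serving signature that a converter would see.
loaded = tf.saved_model.load(export_dir)
signature = loaded.signatures["serving_default"]
_, input_specs = signature.structured_input_signature

for name, spec in input_specs.items():
    print(name, spec.shape, spec.dtype)
```

In this toy case the signature keeps one input whose leading dimension is unknown, even though the function calls tf.unstack. If the signature of the saved dlrm model instead lists many scalar inputs, the unrolling happened when the model was built and saved, which is what the comment above reports.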