PySyft-TensorFlow

Send a custom model without first having to call model.predict(dummy_data)

Open · yanndupis opened this issue · 2 comments

In the Part 2 tutorial, custom models (subclasses of tf.keras.models.Model) require running model.predict(dummy_data) before the model is sent to the worker. This call sets the model's input shape, which tf.keras.models.save_model requires.

Ideally, we would remove this step entirely, or at most require calling model(dummy_data) before sending the model. You can find more information in this conversation.

yanndupis · Oct 22 '19 18:10

We can set the input shape when defining the model, as shown below:

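```python
import tensorflow as tf

# `bob`, `x_train_ptr` and `y_train_ptr` are assumed to come from the tutorial
# setup (a PySyft worker and pointers to training data already sent to it).

class CustomModel(tf.keras.Model):

    def __init__(self, num_classes=10):
        super(CustomModel, self).__init__(name='custom_model')
        self.num_classes = num_classes

        self.flatten = tf.keras.layers.Flatten()
        self.dense_1 = tf.keras.layers.Dense(128, activation='relu')
        self.dense_2 = tf.keras.layers.Dense(num_classes, activation='softmax')

        # Declare the input signature (batches of 28x28 float32 images) up front,
        # so save_model can serialize the model without a prior predict() call.
        # Note: _set_inputs is a private Keras method.
        self._set_inputs(tf.TensorSpec(shape=[None, 28, 28], dtype=tf.float32))

    def call(self, inputs):
        x = self.flatten(inputs)
        x = self.dense_1(x)
        return self.dense_2(x)

model = CustomModel(10)
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# Sending the model to bob serializes it with tf.keras.models.save_model.
model_ptr = model.send(bob)
model_ptr.fit(x_train_ptr, y_train_ptr, epochs=2, validation_split=0.2)
```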

Alternatively, we can simply replace model.predict(dummy_data) with model._set_inputs(tf.TensorSpec(shape=[None, 28, 28], dtype=tf.float32)).

Would either of these be a satisfactory solution?

arshjot · Nov 05 '19 23:11

Hmm, I don't think this is ideal, since _set_inputs is meant to be internal and not exposed to the user. Then again, for the tutorial I do like placing that call in the constructor a bit more than calling model.predict(x). I just reviewed the conversation @yanndupis and I had in the original PR: if Keras explicitly requires its users to call fit, predict, or _set_inputs, then I think it's okay for us to expect the same.

The only thing left to change here would be to handle this case more cleanly in model.send(bob). It would be great to raise our own error that reports the problem and redirects the user, since a user might not realize that sending a model calls save_model under the hood, which could make the failure confusing.
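One possible shape for that custom error, as a minimal sketch: `send_model`, `InputShapeNotSetError`, and the `serialize` callback (a stand-in for the save_model call that sending performs) are all hypothetical names, not PySyft APIs. It also assumes, as with TF 2.0-era Keras, that a subclassed model whose input shape has not been set reports `inputs` as `None`; the stand-in classes below only mimic that attribute.

```python
class InputShapeNotSetError(ValueError):
    """Hypothetical error for models that cannot be serialized yet."""


def send_model(model, worker, serialize):
    """Hypothetical wrapper around sending: validate first, then serialize."""
    # Assumption: an unbuilt subclassed model exposes inputs=None (TF 2.0-era Keras).
    if getattr(model, "inputs", None) is None:
        raise InputShapeNotSetError(
            f"Cannot send {type(model).__name__} to {worker!r}: sending serializes "
            "the model with tf.keras.models.save_model, which needs the input "
            "shape. Call model.predict(dummy_data) or model.fit(...) first."
        )
    # Stand-in for the actual transfer: serialize and pair the payload with the worker.
    return worker, serialize(model)


# Stand-in objects that only mimic the `inputs` attribute.
class UnbuiltModel:
    inputs = None


class BuiltModel:
    inputs = ["input_1"]


try:
    send_model(UnbuiltModel(), "bob", serialize=lambda m: b"")
except InputShapeNotSetError as err:
    print(err)

print(send_model(BuiltModel(), "bob", serialize=lambda m: b"weights"))  # ('bob', b'weights')
```

Checking the precondition before serializing keeps the message specific to this failure. An alternative would be to catch the exception raised during serialization and re-raise it with a clearer message, at the cost of possibly also catching unrelated errors.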

jvmncs · Nov 07 '19 15:11