PySyft-TensorFlow
Send a custom model without having to call model.predict(dummy_data) first
If you look at the Part 2 tutorial, for custom models (tf.keras.models.Model), we need to run model.predict(dummy_data) before sending the model to the worker. This sets the input shape, which is required by tf.keras.models.save_model.
Ideally we would like to remove this step, or at least only have to call model(dummy_data) before sending the model. You can find more information in this conversation.
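For reference, `dummy_data` in the current workaround only needs the right per-example shape, because `predict` is being called just so the model records its input shape; the values themselves do not matter. Below is a minimal sketch of building such a batch from the `[None, 28, 28]` shape used in this thread, with the unknown batch dimension set to 1. `make_dummy_batch` is a hypothetical helper, not part of PySyft-TensorFlow or tf.keras.

```python
import numpy as np


def make_dummy_batch(input_shape, batch_size=1, dtype=np.float32):
    """Hypothetical helper: an all-zeros batch matching a Keras-style input shape.

    `input_shape` uses None for the unknown batch dimension, e.g. [None, 28, 28].
    Only the shape matters for recording the model's inputs, so zeros are enough.
    """
    per_example_shape = tuple(input_shape[1:])
    return np.zeros((batch_size,) + per_example_shape, dtype=dtype)


dummy_data = make_dummy_batch([None, 28, 28])

# The current workaround, assuming `model` and `bob` come from the Part 2 tutorial:
# model.predict(dummy_data)
# model_ptr = model.send(bob)
```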
We can set the input shape while defining the model as shown below:
import tensorflow as tf


class CustomModel(tf.keras.Model):
    def __init__(self, num_classes=10):
        super(CustomModel, self).__init__(name='custom_model')
        self.num_classes = num_classes
        self.flatten = tf.keras.layers.Flatten()
        self.dense_1 = tf.keras.layers.Dense(128, activation='relu')
        self.dense_2 = tf.keras.layers.Dense(num_classes, activation='softmax')
        # Set the input shape up front so the model can be saved (and sent)
        # without first calling model.predict(dummy_data).
        self._set_inputs(tf.TensorSpec(shape=[None, 28, 28], dtype=tf.float32))

    def call(self, inputs):
        x = self.flatten(inputs)
        x = self.dense_1(x)
        return self.dense_2(x)


model = CustomModel(10)
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# `bob`, `x_train_ptr` and `y_train_ptr` are assumed to be set up as in the Part 2 tutorial.
model_ptr = model.send(bob)
model_ptr.fit(x_train_ptr, y_train_ptr, epochs=2, validation_split=0.2)
Or we can simply replace model.predict(dummy_data) with model._set_inputs(tf.TensorSpec(shape=[None, 28, 28], dtype=tf.float32)).
Would either of these be a satisfactory solution?
Hmm, I don't think this is ideal, since the _set_inputs method is meant to be internal and not exposed to the user. Then again, for the tutorial I do like placing it in the constructor a bit more than calling model.predict(x).
I just reviewed the conversation @yanndupis and I had in the original PR: if Keras explicitly requires its users to call fit, predict, or _set_inputs, then I think it's okay for us to expect the same.
The only thing left to change here would be to handle the model.send(bob) case more cleanly. It would be great to have our own error that reports the problem and redirects the user, since a user might not realize that sending a model calls save_model under the hood, which could be confusing.
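A minimal sketch of such an error, assuming that a subclassed tf.keras model whose input shape has not been set reports `inputs` as None or an empty list (as in the TF 2.x versions this thread is about). `ModelInputsNotSetError`, `check_inputs_set` and `send_model` are hypothetical names, not part of PySyft-TensorFlow. The check only inspects attributes, so it can run before `model.send` ever reaches `save_model`.

```python
class ModelInputsNotSetError(ValueError):
    """Hypothetical error: the model's input shape is unknown, so it cannot be serialized."""


def check_inputs_set(model):
    """Raise ModelInputsNotSetError if `model` has no recorded inputs.

    Assumption: an unbuilt subclassed tf.keras.Model exposes `inputs` as None
    or an empty list until fit, predict or _set_inputs has been called.
    """
    if not getattr(model, "inputs", None):
        name = getattr(model, "name", type(model).__name__)
        raise ModelInputsNotSetError(
            f"Cannot send model {name!r}: its input shape has not been set. "
            "Sending a model serializes it with tf.keras.models.save_model, "
            "which needs to know the input shape. Call model.fit(...) or "
            "model.predict(dummy_data) once before sending it."
        )


def send_model(model, worker):
    """Hypothetical wrapper: validate the model first, then delegate to model.send."""
    check_inputs_set(model)
    return model.send(worker)
```

Subclassing ValueError is a design choice: callers that already catch ValueError keep working, while the message points the user at the fix rather than at save_model internals.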