[torch backend] ValueError: Expected an object of type `Trackable`, such as `tf.Module` or a subclass of the `Trackable` class, for export.
With the torch backend, the Keras model cannot be saved in the SavedModel format. Is that expected? If so, then if I develop my code with the torch backend and later want to convert it to the SavedModel format, I have to make sure the code is runnable on both backends (an extra cost).
import os
os.environ["KERAS_BACKEND"] = "torch"

import keras
from keras import ops
import tensorflow as tf


class ComputeSum(keras.Model):
    def __init__(self, input_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.total = self.add_weight(
            name="total",
            initializer="zeros",
            shape=(input_dim,),
            trainable=False,
        )

    def call(self, inputs):
        # Accumulate the column-wise sum of the inputs in a non-trainable weight.
        self.total.assign_add(ops.sum(inputs, axis=0))
        return self.total


x = ops.ones((2, 2))
my_sum = ComputeSum(2)
y = my_sum(x)

tf.saved_model.save(my_sum, '/tmp/saved_model2/')
----> tf.saved_model.save(my_sum, '/tmp/saved_model2/')
> ValueError: Expected an object of type `Trackable`, such as `tf.Module` or a subclass of the `Trackable` class, for export. Got <ComputeSum name=compute_sum_3, built=True> with type <class '__main__.ComputeSum'>.
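For comparison, the same script should go through under the TensorFlow backend, where a Keras model is a `Trackable`, so the failure is specific to the torch backend. A minimal sketch (same model, different backend; the output path is a placeholder):

import os
os.environ["KERAS_BACKEND"] = "tensorflow"  # must be set before keras is imported

import keras
from keras import ops
import tensorflow as tf


class ComputeSum(keras.Model):
    def __init__(self, input_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.total = self.add_weight(
            name="total", initializer="zeros", shape=(input_dim,), trainable=False
        )

    def call(self, inputs):
        self.total.assign_add(ops.sum(inputs, axis=0))
        return self.total


my_sum = ComputeSum(2)
my_sum(ops.ones((2, 2)))

# On the TensorFlow backend the model is trackable, so this should succeed.
tf.saved_model.save(my_sum, '/tmp/saved_model_tf/')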
The same problem appears if we try to load a SavedModel with any backend other than TensorFlow; the following won't work:
import os
os.environ["KERAS_BACKEND"] = "torch"

import keras

keras.Sequential([
    keras.layers.TFSMLayer(
        "saved_model",
        call_endpoint="serving_default",
    )
])
# Error
import scipy.sparse

# get_weights() returns a list of NumPy arrays; sparsify it as a 1 x input_dim matrix.
sparse_weights = scipy.sparse.csr_matrix(my_sum.get_weights())
print(sparse_weights)
How about obtaining the weight?
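As an aside, `get_weights()` returns plain NumPy arrays on every backend, so reading the accumulated state is independent of the export question. A quick check with the `my_sum` model from above (the expected value assumes the single call on `ops.ones((2, 2))`):

weights = my_sum.get_weights()  # list of NumPy arrays, regardless of backend
print(weights)  # roughly [array([2., 2.], dtype=float32)]: column-wise sum of a 2x2 matrix of ones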
This might look somewhat invalid, since I'm using the torch backend and yet calling `tf.saved_model.save` to get a SavedModel, but I was hoping Keras would do some magic here :D
@innat So do I. 👍
My understanding is that, if you want to "switch backends" like this, the only way is to save the model in the .keras format and reload it after enabling the other backend. This assumes that all custom layers are implemented with Keras ops, not directly against one of the backends.
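A minimal sketch of that workflow, assuming a toy model built only with Keras APIs; the file names are placeholders, and the backend switch needs a fresh process because KERAS_BACKEND is read when keras is imported:

# Run 1: develop under the torch backend and save in the .keras format.
import os
os.environ["KERAS_BACKEND"] = "torch"

import keras

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(2)])
model.save("model.keras")

# Run 2 (separate process): reload under the TensorFlow backend and export.
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras

model = keras.saving.load_model("model.keras")  # custom layers need custom_objects=... or registration
model.export("/tmp/saved_model/")  # or tf.saved_model.save(model, ...) as in the original snippet

As noted, any custom layer in the model has to be written with keras.ops for the .keras file to load under a different backend.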
+1 to what lbortolotti said above: use the .keras format to switch backends.
Exporting from PyTorch to SavedModel has been done elsewhere though: https://github.com/pytorch/xla/blob/r2.1/docs/stablehlo.md#convert-saved-stablehlo-for-serving
We might want to explore that to implement the `model.export()` functionality for the PyTorch backend.
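For reference, on the TensorFlow backend this already exists: `model.export()` writes a TF SavedModel with a serving endpoint, which is presumably the behaviour to replicate for torch via the StableHLO route above. A small sketch of the existing TF-backend path (paths are placeholders):

import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(2)])
model.export("/tmp/exported_model/")  # TF SavedModel suitable for serving

# The exported artifact can be reloaded for inference through TFSMLayer.
reloaded = keras.layers.TFSMLayer("/tmp/exported_model/", call_endpoint="serve")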