spark-deep-learning
How to pass custom objects to KerasImageFileTransformer. keras load_model supports custom_objects
Hi, I am using a dice_loss custom object to train my model. Is there any way to pass custom objects to load_model in spark DL? Or does spark DL not support loading models that have custom objects?
When I load the model with Keras directly, I use:
model = tf.keras.models.load_model(mask_model_file,
                                   custom_objects={'bce_dice_loss': bce_dice_loss, 'dice_loss': dice_loss})
as mentioned here https://github.com/keras-team/keras/issues/3977
KerasImageFileTransformer does not appear to support loading custom objects. I am trying to run the code below, which fails:
mask_transformer = KerasImageFileTransformer(inputCol='uri', outputCol='mask',
                                             modelFile=mask_model_file,
                                             imageLoader=load_preprocess_mask_img,
                                             outputMode='vector')
masks = mask_transformer.transform(uri_df)
The stack trace for the failure is:
TypeError Traceback (most recent call last)
<ipython-input-51-8fe840872c2e> in <module>()
----> 1 masks = mask_transformer.transform(uri_df)
/opt/spark-2.3.2/python/pyspark/ml/base.py in transform(self, dataset, params)
171 return self.copy(params)._transform(dataset)
172 else:
--> 173 return self._transform(dataset)
174 else:
175 raise ValueError("Params must be a param map but got %s." % type(params))
/private/var/folders/b5/9rq_y2gx4sz5k92cgzmcfz95cn72xb/T/spark-857b86db-c3f7-4376-a2cf-7b6c8c40ac74/userFiles-8ee3f002-5bbf-44f0-8897-3bda2c93b6e7/databricks_spark-deep-learning-1.2.0-spark2.3-s_2.11.jar/sparkdl/transformers/keras_image.py in _transform(self, dataset)
60 with KSessionWrap() as (sess, keras_graph):
61 graph, inputTensorName, outputTensorName = self._loadTFGraph(sess=sess,
---> 62 graph=keras_graph)
63 image_df = self.loadImagesInternal(dataset, self.getInputCol())
64 transformer = TFImageTransformer(channelOrder='RGB', inputCol=self._loadedImageCol(),
/private/var/folders/b5/9rq_y2gx4sz5k92cgzmcfz95cn72xb/T/spark-857b86db-c3f7-4376-a2cf-7b6c8c40ac74/userFiles-8ee3f002-5bbf-44f0-8897-3bda2c93b6e7/databricks_spark-deep-learning-1.2.0-spark2.3-s_2.11.jar/sparkdl/param/shared_params.py in _loadTFGraph(self, sess, graph)
169 with graph.as_default():
170 K.set_learning_phase(0) # Inference phase
--> 171 model = load_model(self.getModelFile())
172 out_op_name = tfx.op_name(model.output, graph)
173 stripped_graph = tfx.strip_and_freeze_until([out_op_name], graph, sess,
/Users/vivek.vanga/anaconda3/lib/python3.6/site-packages/keras/engine/saving.py in load_model(filepath, custom_objects, compile)
258 raise ValueError('No model found in config file.')
259 model_config = json.loads(model_config.decode('utf-8'))
--> 260 model = model_from_config(model_config, custom_objects=custom_objects)
261
262 # set weights
/Users/vivek.vanga/anaconda3/lib/python3.6/site-packages/keras/engine/saving.py in model_from_config(config, custom_objects)
332 '`Sequential.from_config(config)`?')
333 from ..layers import deserialize
--> 334 return deserialize(config, custom_objects=custom_objects)
335
336
/Users/vivek.vanga/anaconda3/lib/python3.6/site-packages/keras/layers/__init__.py in deserialize(config, custom_objects)
53 module_objects=globs,
54 custom_objects=custom_objects,
---> 55 printable_module_name='layer')
/Users/vivek.vanga/anaconda3/lib/python3.6/site-packages/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
143 config['config'],
144 custom_objects=dict(list(_GLOBAL_CUSTOM_OBJECTS.items()) +
--> 145 list(custom_objects.items())))
146 with CustomObjectScope(custom_objects):
147 return cls.from_config(config['config'])
/Users/vivek.vanga/anaconda3/lib/python3.6/site-packages/keras/engine/network.py in from_config(cls, config, custom_objects)
1025 if layer in unprocessed_nodes:
1026 for node_data in unprocessed_nodes.pop(layer):
-> 1027 process_node(layer, node_data)
1028
1029 name = config.get('name')
/Users/vivek.vanga/anaconda3/lib/python3.6/site-packages/keras/engine/network.py in process_node(layer, node_data)
984 # and building the layer if needed.
985 if input_tensors:
--> 986 layer(unpack_singleton(input_tensors), **kwargs)
987
988 def process_layer(layer_data):
/Users/vivek.vanga/anaconda3/lib/python3.6/site-packages/keras/engine/base_layer.py in __call__(self, inputs, **kwargs)
429 'You can build it manually via: '
430 '`layer.build(batch_input_shape)`')
--> 431 self.build(unpack_singleton(input_shapes))
432 self.built = True
433
/Users/vivek.vanga/anaconda3/lib/python3.6/site-packages/keras/layers/normalization.py in build(self, input_shape)
90
91 def build(self, input_shape):
---> 92 dim = input_shape[self.axis]
93 if dim is None:
94 raise ValueError('Axis ' + str(self.axis) + ' of '
TypeError: tuple indices must be integers or slices, not list
@vvivek921 I currently have the same problem. Did you find a way to load custom objects?
The KerasImageFileTransformer is meant to be used with images. Try the KerasTransformer class instead.
Example:
from sparkdl import KerasTransformer
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
# Generate random input data
num_features = 10
num_examples = 100
input_data = [{"features" : np.random.randn(num_features).tolist()} for i in range(num_examples)]
input_df = sqlContext.createDataFrame(input_data)
# Create and save a single-hidden-layer Keras model for binary classification
# NOTE: In a typical workflow, we'd train the model before exporting it to disk,
# but we skip that step here for brevity
model = Sequential()
model.add(Dense(units=20, input_shape=[num_features], activation='relu'))
model.add(Dense(units=1, activation='sigmoid'))
model_path = "/tmp/simple-binary-classification"
model.save(model_path)
# Create transformer and apply it to our input data
transformer = KerasTransformer(inputCol="features", outputCol="predictions", modelFile=model_path)
final_df = transformer.transform(input_df)
Source: https://github.com/databricks/spark-deep-learning#working-with-images-in-spark
@ghunkins yep, I know. My problem is that, for instance, MobileNet uses custom objects like ReLU(6.) and DepthwiseConv2D, which can be loaded in Keras within a CustomObjectScope. However, sparkdl's KerasImageFileTransformer doesn't seem to have a parameter for loading those custom objects.
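For reference, the plain-Keras workaround looks roughly like this (a minimal sketch, assuming a Keras ~2.1 install where relu6 and DepthwiseConv2D are importable from keras.applications.mobilenet; 'mobilenet.h5' is a placeholder path):

from keras.models import load_model
from keras.utils.generic_utils import CustomObjectScope
from keras.applications.mobilenet import relu6, DepthwiseConv2D

# The scope makes the custom activation and layer visible while the model
# config is deserialized; 'mobilenet.h5' is a placeholder model file.
with CustomObjectScope({'relu6': relu6, 'DepthwiseConv2D': DepthwiseConv2D}):
    model = load_model('mobilenet.h5')

KerasImageFileTransformer exposes no equivalent hook, which is exactly the problem.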
@MrBanhBao You could edit the load_model call in KerasImageFileTransformer (the shared_params.py frame in the trace above) to accept a custom_objects parameter, then rebuild spark DL and use that. I haven't tried this myself, though; I ended up not using spark DL. A rebuild-free alternative is sketched below.
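For anyone who still wants to try it without rebuilding: the deserialize_keras_object frame in the trace above merges _GLOBAL_CUSTOM_OBJECTS with the per-call custom_objects, and load_model runs on the driver inside transform(), so registering the objects globally first might be enough. This is an untested sketch; bce_dice_loss, dice_loss, load_preprocess_mask_img, mask_model_file, and uri_df are the names from the original question.

from keras.utils.generic_utils import get_custom_objects
from sparkdl import KerasImageFileTransformer

# Untested: register the custom losses globally so that sparkdl's internal
# load_model(self.getModelFile()) call can resolve them through
# _GLOBAL_CUSTOM_OBJECTS (see the deserialize_keras_object frame above).
get_custom_objects().update({'bce_dice_loss': bce_dice_loss,
                             'dice_loss': dice_loss})

mask_transformer = KerasImageFileTransformer(inputCol='uri', outputCol='mask',
                                             modelFile=mask_model_file,
                                             imageLoader=load_preprocess_mask_img,
                                             outputMode='vector')
masks = mask_transformer.transform(uri_df)

Note that the TypeError at the bottom of the trace comes from standalone Keras rebuilding a BatchNormalization layer, which is a known symptom of loading a model saved with tf.keras into standalone keras, so the custom-object fix alone may not resolve this particular failure.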