
Model subclassing compatibility

[Open] lminer opened this issue 3 years ago • 4 comments

I've been trying to get cvnn to work with the Keras model subclassing API, but for some reason the first layer of the model always expects the data to be float32. Any idea how to make this work?

lminer, Apr 22 '22 17:04

Did you add the ComplexInput layer at the start? If not, TF will automatically use the standard TensorFlow Input and cast the data to float.

NEGU93, Apr 22 '22 17:04
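To see why an accidental cast to float breaks a complex model, here is a minimal sketch that assumes the cast behaves like an explicit `tf.cast`: converting complex64 data to float32 keeps only the real part, so the imaginary component is lost before the first layer ever sees it (TensorFlow may log a warning about this).

```python
import tensorflow as tf

z = tf.constant([1 + 2j, 3 - 4j], dtype=tf.complex64)

# Casting complex -> real keeps the real part and drops the imaginary part.
as_float = tf.cast(z, tf.float32)
print(as_float.numpy())  # [1. 3.]
```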

How would you do that in a subclassed model? I have custom train and test steps, so I can't just do:

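import tensorflow as tf
from cvnn.layers import ComplexInput

inputs = ComplexInput((1, 2, 3))
outputs = SubClassedModel()(inputs)  # instantiate the subclassed model, then call it
model = tf.keras.Model(inputs, outputs)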

lminer, Apr 22 '22 17:04

It should be possible... did it work? I have never worked with subclassed models, so I don't know how they work. How do they differ from a normal model?

NEGU93, Apr 27 '22 12:04

I couldn't get it to work. The approach above would have required too big a refactor. Subclassed models are just a PyTorch-like interface: you inherit from tf.keras.Model, build the layers in the constructor, and implement the forward pass in call(). It's the same as writing a custom layer, but for a whole model.

lminer, Apr 27 '22 16:04
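A minimal sketch of the subclassing pattern described above, written in plain TensorFlow rather than cvnn. `ToyComplexDense`, `ComplexModel`, `complex_mse` and `train_step` are hypothetical names for illustration only and are not part of cvnn's API. A subclassed model has no `Input` layer, so it receives whatever tensor is passed to `call()`. Keras's automatic input casting is intended for floating-point tensors, so a complex64 tensor should reach the first layer unchanged. Whether this also avoids the float32 behaviour seen with cvnn's own layers is an assumption and has not been tested here.

```python
import tensorflow as tf

tf.random.set_seed(0)


class ToyComplexDense(tf.keras.layers.Layer):
    """Hypothetical complex dense layer (NOT cvnn's API): y = x @ (W_re + i*W_im)."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        in_dim = int(input_shape[-1])
        # Real and imaginary parts are stored as ordinary float32 variables.
        self.w_re = self.add_weight(name="w_re", shape=(in_dim, self.units),
                                    initializer="glorot_uniform")
        self.w_im = self.add_weight(name="w_im", shape=(in_dim, self.units),
                                    initializer="glorot_uniform")

    def call(self, inputs):
        # Assemble the complex kernel on every call; the complex input is used as-is.
        kernel = tf.complex(self.w_re, self.w_im)
        return tf.matmul(inputs, kernel)


class ComplexModel(tf.keras.Model):
    """Subclassed model: layers built in __init__, forward pass in call()."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.hidden = ToyComplexDense(4)
        self.head = ToyComplexDense(1)

    def call(self, inputs):
        return self.head(self.hidden(inputs))


def complex_mse(y_true, y_pred):
    # Real-valued loss: mean squared modulus of the complex error.
    return tf.reduce_mean(tf.abs(y_pred - y_true) ** 2)


model = ComplexModel()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)


def train_step(x, y):
    # Hand-written training step using a gradient tape.
    with tf.GradientTape() as tape:
        loss = complex_mse(y, model(x))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss


x = tf.complex(tf.random.normal((8, 3)), tf.random.normal((8, 3)))
y = tf.complex(tf.random.normal((8, 1)), tf.random.normal((8, 1)))

pred = model(x)
loss_before = complex_mse(y, pred)
for _ in range(50):
    train_step(x, y)
loss_after = complex_mse(y, model(x))
print(pred.dtype, float(loss_before), float(loss_after))
```

The standalone `train_step` keeps the sketch independent of differences between Keras versions. If `fit()` is needed, the same tape-and-apply body could be moved into an overridden `tf.keras.Model.train_step`, with a matching `test_step` for evaluation.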