keras-onnx
Conversion of an fft and ifft Lambda layer
Hi all,
I am trying to convert a model consisting of two simple Lambda layers:
```python
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Lambda
import keras2onnx

def fftLayer(x):
    # expanding dimensions
    frame = tf.expand_dims(x, axis=1)
    # calculating the fft over the time frames. rfft returns NFFT/2+1 bins.
    stft_dat = tf.signal.rfft(frame)
    # calculating magnitude and phase from the complex signal
    mag = tf.abs(stft_dat)
    phase = tf.math.angle(stft_dat)
    # returning magnitude and phase
    return [mag, phase]

def ifftLayer(x):
    # calculating the complex representation
    s1_stft = (tf.cast(x[0], tf.complex64) *
               tf.exp(1j * tf.cast(x[1], tf.complex64)))
    # returning the time domain frames
    out = tf.signal.irfft(s1_stft)
    return out

time_dat = Input(batch_shape=(1, 512))
mag, angle = Lambda(fftLayer)(time_dat)
estimated_frames = Lambda(ifftLayer)([mag, angle])
model = Model(inputs=time_dat, outputs=estimated_frames)

# converting the model
onnx_model = keras2onnx.convert_keras(model)
```
It returns the following error message:
```
WARN: No corresponding ONNX op matches the tf.op node lambda_1/irfft of type IRFFT
The generated ONNX model needs run with the custom op supports.
Traceback (most recent call last):
  File "test_model_script.py", line 16, in <module>
    test_class.convert_to_onnx()
  File "/home/XX/create_model_class.py", line 113, in convert_to_onnx
    onnx_model = keras2onnx.convert_keras(self.model)
  File "/home/xx/anaconda3/envs/tf22env/lib/python3.7/site-packages/keras2onnx/main.py", line 80, in convert_keras
    parse_graph(topology, tf_graph, target_opset, output_names, output_dict)
  File "/home/xx/anaconda3/envs/tf22env/lib/python3.7/site-packages/keras2onnx/parser.py", line 841, in parse_graph
    ) if is_tf2 and is_tf_keras else _parse_graph_core(
  File "/home/xx/anaconda3/envs/tf22env/lib/python3.7/site-packages/keras2onnx/parser.py", line 729, in _parse_graph_core_v2
    _on_parsing_tf_nodes(graph, layer_info.nodelist, varset, topology.debug_mode)
  File "/home/xx/anaconda3/envs/tf22env/lib/python3.7/site-packages/keras2onnx/parser.py", line 318, in _on_parsing_tf_nodes
    out0 = varset.get_local_variable_or_declare_one(oname, infer_variable_type(o_, varset.target_opset))
  File "/home/xx/anaconda3/envs/tf22env/lib/python3.7/site-packages/keras2onnx/_parser_tf.py", line 48, in infer_variable_type
    "Unable to find out a correct type for tensor type = {} of {}".format(tensor_type, tensor.name))
ValueError: Unable to find out a correct type for tensor type = <dtype: 'complex64'> of lambda_1/Cast_1:0
```
It tells me to use custom ops. What is the best way to do that for this case? Does the converter support complex64?
Thanks for your help in advance!
Best, Nils
The reason is that onnxruntime does not support IRFFT yet, so the converter has no corresponding conversion for that tf op. The converter does not support complex64 either. One option is to write a custom op in onnxruntime. The other is to rewrite the FFT to work on separate real and imaginary parts.
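To illustrate the second option, here is a minimal sketch (my own illustration, not keras2onnx code; names like `rfft_as_matmul` are my choice, and numpy stands in for the equivalent tf ops): an rFFT of a real frame is just two real matrix multiplications with a cosine and a sine basis, and MatMul, Sqrt, Mul etc. all have ONNX conversions.

```python
import numpy as np

N = 512
k = np.arange(N // 2 + 1)[:, None]           # frequency bins: NFFT/2 + 1
n = np.arange(N)[None, :]                    # time indices
cos_basis = np.cos(2 * np.pi * k * n / N)    # real part of the DFT basis
sin_basis = -np.sin(2 * np.pi * k * n / N)   # imaginary part of the DFT basis

def rfft_as_matmul(frame):
    """Real and imaginary parts of the rFFT via real-valued matmuls only."""
    return cos_basis @ frame, sin_basis @ frame

# magnitude and phase from the two real tensors
# (the tf.abs / tf.math.angle step, without any complex dtype)
frame = np.random.randn(N)
re, im = rfft_as_matmul(frame)
mag = np.sqrt(re ** 2 + im ** 2)
phase = np.arctan2(im, re)
```

In the Keras model the two bases would become constant weights of `tf.matmul` calls (or Dense layers with fixed weights), so the exported graph contains no complex tensors.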
So the following calculation will in general not work with ONNX?

```python
complex_stft = (tf.cast(mag, tf.complex64) *
                tf.exp(1j * tf.cast(phase, tf.complex64)))
```
Do you have suggestions for a good tutorial on writing custom ops in onnxruntime?
Do you know of any example of rewriting the FFT?
Thanks again!
To write a custom op, see here. Once you are done, you also need to add complex64 support in keras2onnx. As for rewriting the FFT, I am not an expert on that.
Thanks again for your answer. It looks like somebody is already working on an op for fft/ifft: https://github.com/onnx/onnx/pull/2625. So ONNX supports a complex64/128 representation, but the converter does not? At the moment I don't have the time to work on a custom op, so I will pull the complex computations out of the model for now.
@breizhn Unfortunately I got a little overwhelmed by work and haven't made progress on the fft/ifft operators PR. I should definitely resume it. But that PR covers only the specification part; implementation work will remain for any backend that does not already provide an FFT implementation (TF and PyTorch do provide such implementations).