keras-onnx

Conversion of an fft and ifft Lambda layer

Open breizhn opened this issue 5 years ago • 5 comments

Hi all,

I am trying to convert a model that consists of two simple Lambda layers:

```python
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Lambda
import keras2onnx

def fftLayer(x):
    # expanding dimensions: (batch, 512) -> (batch, 1, 512)
    frame = tf.expand_dims(x, axis=1)
    # calculating the fft over the time frames. rfft returns NFFT/2+1 bins.
    stft_dat = tf.signal.rfft(frame)
    # calculating magnitude and phase from the complex signal
    mag = tf.abs(stft_dat)
    phase = tf.math.angle(stft_dat)
    # returning magnitude and phase
    return [mag, phase]

def ifftLayer(x):
    # x is the list [mag, phase] passed to the Lambda layer
    # calculating the complex representation mag * exp(j * phase)
    s1_stft = (tf.cast(x[0], tf.complex64) *
               tf.exp(1j * tf.cast(x[1], tf.complex64)))
    # returning the time domain frames
    out = tf.signal.irfft(s1_stft)
    return out

time_dat = Input(batch_shape=(1, 512))
mag, angle = Lambda(fftLayer)(time_dat)
estimated_frames = Lambda(ifftLayer)([mag, angle])
model = Model(inputs=time_dat, outputs=estimated_frames)
# converting model
onnx_model = keras2onnx.convert_keras(model)
```

It returns the following error message:

```
WARN: No corresponding ONNX op matches the tf.op node lambda_1/irfft of type IRFFT
      The generated ONNX model needs run with the custom op supports.
Traceback (most recent call last):
  File "test_model_script.py", line 16, in <module>
    test_class.convert_to_onnx()
  File "/home/XX/create_model_class.py", line 113, in convert_to_onnx
    onnx_model = keras2onnx.convert_keras(self.model)
  File "/home/xx/anaconda3/envs/tf22env/lib/python3.7/site-packages/keras2onnx/main.py", line 80, in convert_keras
    parse_graph(topology, tf_graph, target_opset, output_names, output_dict)
  File "/home/xx/anaconda3/envs/tf22env/lib/python3.7/site-packages/keras2onnx/parser.py", line 841, in parse_graph
    ) if is_tf2 and is_tf_keras else _parse_graph_core(
  File "/home/xx/anaconda3/envs/tf22env/lib/python3.7/site-packages/keras2onnx/parser.py", line 729, in _parse_graph_core_v2
    _on_parsing_tf_nodes(graph, layer_info.nodelist, varset, topology.debug_mode)
  File "/home/xx/anaconda3/envs/tf22env/lib/python3.7/site-packages/keras2onnx/parser.py", line 318, in _on_parsing_tf_nodes
    out0 = varset.get_local_variable_or_declare_one(oname, infer_variable_type(o_, varset.target_opset))
  File "/home/xx/anaconda3/envs/tf22env/lib/python3.7/site-packages/keras2onnx/_parser_tf.py", line 48, in infer_variable_type
    "Unable to find out a correct type for tensor type = {} of {}".format(tensor_type, tensor.name))
ValueError: Unable to find out a correct type for tensor type = <dtype: 'complex64'> of lambda_1/Cast_1:0
```

It tells me to use custom ops. What is the best way to do that for this case? Does the converter support complex64?

Thanks for your help in advance!

Best, Nils

breizhn, Jun 29 '20 12:06

The reason is that onnxruntime does not support IRFFT yet, so the converter has no conversion for the corresponding TF op. The converter also does not support complex64. One option is to write custom ops in onnxruntime; the other is to rewrite the FFT so that the real and imaginary parts are handled separately.
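A minimal sketch of the second option (an illustration, not code posted in the issue): an FFT of fixed length is a linear map, so `rfft` can be written as two matrix products with precomputed cosine and sine matrices, and the inverse likewise, using only real float32 arrays. It is written in NumPy so it can be checked against `np.fft.rfft`; the helper names `dft_matrices`, `rfft_real_imag` and `irfft_from_real_imag` are hypothetical. The assumption is that in a Keras model the matrices would be stored as constants and the products computed with `tf.matmul`, which should map to the ONNX `MatMul` op; this has not been verified with keras2onnx here.

```python
import numpy as np


def dft_matrices(n_fft):
    """Cosine and sine bases for a real FFT of length n_fft (n_fft // 2 + 1 bins)."""
    n = np.arange(n_fft)[:, None]            # time index, shape (n_fft, 1)
    k = np.arange(n_fft // 2 + 1)[None, :]   # frequency bin, shape (1, bins)
    theta = 2.0 * np.pi * n * k / n_fft      # shape (n_fft, bins)
    return np.cos(theta).astype(np.float32), np.sin(theta).astype(np.float32)


def rfft_real_imag(x, cos_m, sin_m):
    """Forward real FFT over the last axis, returned as (real, imag) float arrays."""
    # X[k] = sum_n x[n] * (cos(theta) - 1j * sin(theta))
    return x @ cos_m, -(x @ sin_m)


def irfft_from_real_imag(real, imag, cos_m, sin_m, n_fft):
    """Inverse of rfft_real_imag using only real arithmetic."""
    # Bins strictly between DC and Nyquist also stand for their conjugate
    # mirror bins, so they are counted twice; DC (and Nyquist for even n_fft) once.
    w = np.full(real.shape[-1], 2.0, dtype=np.float32)
    w[0] = 1.0
    if n_fft % 2 == 0:
        w[-1] = 1.0
    # x[n] = (1 / N) * sum_k w[k] * (Re X[k] cos(theta) - Im X[k] sin(theta))
    return ((real * w) @ cos_m.T - (imag * w) @ sin_m.T) / n_fft


# Usage with the frame size from the question
rng = np.random.default_rng(0)
frame = rng.standard_normal((1, 512)).astype(np.float32)
cos_m, sin_m = dft_matrices(512)
re, im = rfft_real_imag(frame, cos_m, sin_m)
restored = irfft_from_real_imag(re, im, cos_m, sin_m, 512)
```

The trade-off is cost: a matrix-product DFT is O(N²) per frame instead of O(N log N), which is usually tolerable for short frames such as 512 samples.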

jiafatom, Jun 29 '20 13:06

So the following calculation will, in general, not work with ONNX?

```python
complex_stft = (tf.cast(mag, tf.complex64) * tf.exp((1j * tf.cast(phase, tf.complex64))))
```
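For reference, the product itself does not require a complex dtype. By Euler's formula, `mag * exp(1j * phase)` has real part `mag * cos(phase)` and imaginary part `mag * sin(phase)`, which involve only float tensors. A NumPy sketch of that identity follows (the helper name `polar_to_rect` is hypothetical); in TensorFlow the same lines would use `tf.cos`, `tf.sin` and element-wise multiplication, which exist in ONNX as `Cos`, `Sin` and `Mul`. This alone does not remove the `IRFFT` node: the inverse transform still needs a real-valued formulation or a custom op.

```python
import numpy as np


def polar_to_rect(mag, phase):
    """Return Re and Im of mag * exp(1j * phase) without creating a complex array."""
    # Euler's formula: exp(1j * phase) = cos(phase) + 1j * sin(phase)
    return mag * np.cos(phase), mag * np.sin(phase)


# Shapes match the rfft output in the question: (batch, 1, 512 // 2 + 1)
rng = np.random.default_rng(1)
mag = rng.random((1, 1, 257)).astype(np.float32)
phase = rng.uniform(-np.pi, np.pi, size=(1, 1, 257)).astype(np.float32)
real, imag = polar_to_rect(mag, phase)
```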

Do you have suggestions for a good tutorial on writing custom ops in onnxruntime?

Do you know of any example of rewriting the FFT?

Thanks again!

breizhn, Jun 29 '20 14:06

To write a custom op, see here. Once you are done, you also need to add complex64 support in keras2onnx. As for rewriting the FFT, I am not an expert on that.

jiafatom, Jun 29 '20 14:06

Thanks again for your answer. It looks like somebody is already working on an op for fft/ifft (https://github.com/onnx/onnx/pull/2625). ONNX supports complex64/complex128 types, but the converter does not? At the moment I don't have time to work on a custom op, so for now I will move the complex computations out of the model.
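A minimal sketch of that workaround, assuming the network only needs the magnitude: the FFT, the phase and the inverse FFT run in NumPy outside the exported model, so only real float32 tensors cross the model boundary. `run_onnx_model` and `process_frame` are hypothetical names, and the model is replaced by an identity so the round trip can be checked.

```python
import numpy as np


def run_onnx_model(mag):
    """Hypothetical stand-in for the converted network (identity here).

    With onnxruntime this would typically be
    session.run(None, {input_name: mag})[0].
    """
    return mag


def process_frame(frame):
    """FFT pre-processing and IFFT post-processing done outside the model."""
    spec = np.fft.rfft(frame, axis=-1)            # complex spectrum, NFFT/2+1 bins
    mag = np.abs(spec).astype(np.float32)         # real-valued model input
    phase = np.angle(spec)                        # kept outside the model
    est_mag = run_onnx_model(mag)                 # only float32 enters the model
    # Recombine with the original phase and return time-domain frames
    return np.fft.irfft(est_mag * np.exp(1j * phase), n=frame.shape[-1], axis=-1)


rng = np.random.default_rng(2)
frame = rng.standard_normal((1, 512)).astype(np.float32)
out = process_frame(frame)
```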

breizhn, Jun 29 '20 14:06

@breizhn Unfortunately, I got a little overwhelmed by work and didn't make progress on the fft/ifft operators PR. I should definitely resume it. However, this only covers the specification part. Some implementation work will remain for the backends that do not already provide an FFT implementation (TF and PyTorch do provide such implementations).

jeremycochoy, Jul 03 '20 09:07