migrate keras2 code to keras3.0 with pytorch backend, got inplace error

Open sataliulan opened this issue 1 year ago • 4 comments

Hi there, I tried to migrate my Keras 2.15 code to Keras 3.0 (see the source code below). The code ran well in Keras 2 with the TensorFlow 2.15 backend, but raised an 'inplace' error in Keras 3.0 with the PyTorch 2.1.1 backend. Any ideas?

source code in keras2 with tf2.15 backend:

from keras.models import Model
from keras import layers
inputs = layers.Input(shape=(48,1))
# inputs = layers.Input(shape=(X_train.shape[1], X_train.shape[2]))

lstm = layers.Bidirectional(layers.LSTM(64,return_sequences=True,dropout=0.5))(inputs, training = True)
lstm = layers.Bidirectional(layers.LSTM(16,return_sequences=False,dropout=0.5))(lstm, training = True)
dense = layers.Dense(50)(lstm)
out10 = layers.Dense(1,name='quantile10')(dense)
out50 = layers.Dense(1,name='quantile50')(dense)
out90 = layers.Dense(1,name='quantile90')(dense)
Lmodel = Model(inputs=inputs,outputs=[out10,out50,out90])
Lmodel.summary()
losses = [lambda y,f: q_loss(0.1,y,f), lambda y,f:q_loss(0.5,y,f), lambda y,f:q_loss(0.9,y,f)]
Lmodel.compile(loss= losses,optimizer = 'adam', loss_weights=[0.3,0.3,0.3])
# model test
import tensorflow as tf
tempX = tf.random.normal((4,48,1))
tempY= tf.ones(4,)
Lmodel.fit(x=tempX,y={'quantile10':tempY,"quantile50":tempY,"quantile90":tempY},
epochs=1,batch_size=1,verbose=2)

source code in keras3.0 with pytorch 2.1.1 backend:

# !export KERAS_BACKEND='torch'

import os
os.environ['KERAS_BACKEND']='torch'
import keras
print(keras.__version__)
from keras.models import Model
from keras import layers
inputs = layers.Input(shape=(48,1))
# inputs = layers.Input(shape=(X_train.shape[1], X_train.shape[2]))
lstm = layers.Bidirectional(layers.LSTM(64,return_sequences=True,dropout=0.5))(inputs, training = True)
lstm = layers.Bidirectional(layers.LSTM(16,return_sequences=False,dropout=0.5))(lstm, training = True)
dense = layers.Dense(50)(lstm)
out10 = layers.Dense(1,name='quantile10')(dense)
out50 = layers.Dense(1,name='quantile50')(dense)
out90 = layers.Dense(1,name='quantile90')(dense)
Lmodel = Model(inputs=inputs,outputs=[out10,out50,out90])
Lmodel.summary()
losses = [lambda y,f: q_loss(0.1,y,f), lambda y,f:q_loss(0.5,y,f), lambda y,f:q_loss(0.9,y,f)]
Lmodel.compile(loss= losses,optimizer = 'adam', loss_weights=[0.3,0.3,0.3])
# model test
import torch  # needed for the test tensors below
tempX = torch.randn((4,48,1))
tempY = torch.ones(4)
Lmodel.fit(x=tempX,y={'quantile10':tempY,"quantile50":tempY,"quantile90":tempY},
epochs=1,batch_size=1,verbose=2)

Screenshot 2023-12-06 221023

The error log (as shown in the screenshot above):

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/sdd/anaconda3/envs/torch_lightning/lib/python3.9/site-packages/keras/src/utils/traceback_utils.py", line 123, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/sdd/anaconda3/envs/torch_lightning/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/sdd/anaconda3/envs/torch_lightning/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/sdd/anaconda3/envs/torch_lightning/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/sdd/anaconda3/envs/torch_lightning/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/sdd/anaconda3/envs/torch_lightning/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/sdd/anaconda3/envs/torch_lightning/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/sdd/anaconda3/envs/torch_lightning/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/sdd/anaconda3/envs/torch_lightning/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
RuntimeError: Exception encountered when calling LSTMCell.call().

Output 0 of UnbindBackward0 is a view and is being modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.

Arguments received by LSTMCell.call():
  • inputs=torch.Tensor(shape=torch.Size([1, 128]), dtype=float32)
  • states=('torch.Tensor(shape=torch.Size([1, 16]), dtype=float32)', 'torch.Tensor(shape=torch.Size([1, 16]), dtype=float32)')
  • training=True

sataliulan avatar Dec 06 '23 14:12 sataliulan

Hi,

Thanks for reporting the issue. I tried to reproduce the error, but I could not find q_loss defined in the code. Please see the attached Gist here and provide a sample of reproducible code.

sachinprasadhs avatar Dec 06 '23 18:12 sachinprasadhs

loss for keras 2.15

def q_loss(q,y,f):  # q: quantile, y: true label, f: predicted label
    error = y-f
    # return keras.ops.mean(keras.ops.maximum(q*error,(q-1)*error),axis=-1)
    return keras.backend.mean(keras.backend.maximum(q*error,(q-1)*error),axis=-1)

loss for keras 3.0

def q_loss(q,y,f):  # q: quantile, y: true label, f: predicted label
    error = y-f
    return keras.ops.mean(keras.ops.maximum(q*error,(q-1)*error),axis=-1)
    # return keras.backend.mean(keras.backend.maximum(q*error,(q-1)*error),axis=-1)
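
For anyone checking the numbers by hand, here is a minimal framework-free sketch of the same quantile (pinball) loss. It is plain Python rather than Keras code, the name `pinball_loss` and the sample values are made up for illustration, and it averages over all samples instead of over the last axis as q_loss does.

```python
def pinball_loss(q, y_true, y_pred):
    """Mean quantile (pinball) loss over paired lists of floats.

    Follows mean(maximum(q*error, (q-1)*error)) from q_loss,
    but works on plain Python lists instead of Keras tensors.
    """
    # error = y - f, one value per sample
    errors = [y - f for y, f in zip(y_true, y_pred)]
    # Under-prediction (error > 0) is weighted by q,
    # over-prediction (error < 0) by (1 - q).
    per_sample = [max(q * e, (q - 1) * e) for e in errors]
    return sum(per_sample) / len(per_sample)


# Hypothetical sample data, not taken from the issue.
y_true = [1.0, 1.0, 1.0, 1.0]
y_pred = [0.5, 1.5, 1.0, 2.0]

for q in (0.1, 0.5, 0.9):
    print(q, pinball_loss(q, y_true, y_pred))
```

With q below 0.5 the loss penalizes over-prediction more heavily, and with q above 0.5 it penalizes under-prediction more heavily, which is why the three outputs are trained with q = 0.1, 0.5 and 0.9.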

sataliulan avatar Dec 07 '23 00:12 sataliulan

Hi,

> Thanks for reporting the issue, I was trying to reproduce the error, I could not find the q_loss defined in the code, please find the attached Gist here and provide sample reproducible code.

I rewrote my code in your Colab and shared the running Gist here. Looking forward to your reply.

sataliulan avatar Dec 07 '23 01:12 sataliulan

I was able to run the code without any error in Keras 3.0.2. Please see the Gist below and let us know if there is anything I'm missing.

https://colab.sandbox.google.com/gist/sachinprasadhs/6329e51aa49fa697c09185f912b5049e/18895.ipynb

sachinprasadhs avatar Dec 27 '23 23:12 sachinprasadhs

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

github-actions[bot] avatar Apr 26 '24 01:04 github-actions[bot]

This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.

github-actions[bot] avatar May 10 '24 01:05 github-actions[bot]
