iree
iree copied to clipboard
KeyError when compiling a TensorFlow module with the latest IREE compiler
What happened?
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/tensorflow/python/saved_model/function_serialization.py", line 67, in serialize_concrete_function
    bound_inputs.append(node_ids[capture])
                        ~~~~~~~~^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/tensorflow/python/util/object_identity.py", line 136, in __getitem__
    return self._storage[self._wrap_key(key)]
           ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^
KeyError: <_ObjectIdentityWrapper wrapping <tf.Tensor: shape=(), dtype=resource, value=<ResourceHandle(name="SGD/learning_rate/5", device="/job:localhost/replica:0/task:0/device:CPU:0", container="Anonymous", type="tensorflow::Var", dtype and shapes : "[ DType enum: 1, Shape: [] ]")>>>
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/Users/lingzhi/_Code/KCLOpenSource/kcl/a.py", line 68, in <module>
Steps to reproduce your issue
- Write the code (main.py) and set up the dependencies (iree-compiler version 20240228.815, macOS arm64, Python 3.11):
# python3 -m pip install matplotlib
# python3 -m pip install --upgrade tf-nightly # Needed for stablehlo export in TF>=2.14
# python3 -m pip install iree-compiler iree-runtime iree-tools-tf
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf
import iree.compiler.tf
import iree.runtime
# Seed NumPy's and TensorFlow's global RNGs with one shared value so that
# repeated runs of this script are reproducible.
_SEED = 91
np.random.seed(_SEED)
tf.random.set_seed(_SEED)

# Geometry shared by the Keras model and the tf.function input signatures.
NUM_CLASSES = 10
NUM_ROWS = 28
NUM_COLS = 28
BATCH_SIZE = 32
class TrainableDNN(tf.Module):
    """Small dense classifier exposing an inference entry point and a training step.

    Both ``predict`` and ``learn`` are ``tf.function``s with fixed input
    signatures so that they can be exported by name through
    ``iree.compiler.tf.compile_module``.
    """

    def __init__(self):
        super().__init__()
        # Create a Keras model to train. The input shape is ordered
        # (rows, cols, channels) to match the TensorSpecs used by
        # predict/learn below.
        inputs = tf.keras.layers.Input((NUM_ROWS, NUM_COLS, 1))
        x = tf.keras.layers.Flatten()(inputs)
        x = tf.keras.layers.Dense(128)(x)
        x = tf.keras.layers.Activation("relu")(x)
        x = tf.keras.layers.Dense(NUM_CLASSES)(x)
        # The model ends in Softmax, so it emits probabilities, not logits.
        outputs = tf.keras.layers.Softmax()(x)
        self.model = tf.keras.Model(inputs, outputs)

        # Create a loss function and optimizer to use during training.
        # No from_logits argument is passed, so the loss is fed the
        # probabilities produced by the Softmax layer above.
        self.loss = tf.keras.losses.SparseCategoricalCrossentropy()
        # NOTE(review): compiling this module has been observed to fail with a
        # KeyError inside TF's serialize_concrete_function on the captured
        # resource "SGD/learning_rate": `learn` captures this optimizer's
        # learning-rate variable, but it is apparently missing from the
        # object graph TF builds for the module. Possibly an optimizer-tracking
        # issue with newer Keras versions — TODO confirm root cause.
        self.optimizer = tf.keras.optimizers.SGD(learning_rate=1e-2)

    @tf.function(
        input_signature=[tf.TensorSpec([BATCH_SIZE, NUM_ROWS, NUM_COLS, 1])]  # inputs
    )
    def predict(self, inputs):
        """Run inference on one batch.

        Returns the Softmax output of the model, shape (BATCH_SIZE, NUM_CLASSES).
        """
        return self.model(inputs, training=False)

    # We compile the entire training step by making it a method on the model.
    @tf.function(
        input_signature=[
            tf.TensorSpec([BATCH_SIZE, NUM_ROWS, NUM_COLS, 1]),  # inputs
            tf.TensorSpec([BATCH_SIZE], tf.int32),  # labels
        ]
    )
    def learn(self, inputs, labels):
        """Apply one SGD update for a batch and return the loss computed before it.

        Args:
            inputs: float batch of shape (BATCH_SIZE, NUM_ROWS, NUM_COLS, 1).
            labels: int32 class indices of shape (BATCH_SIZE,).
        """
        # Capture the gradients from forward prop...
        with tf.GradientTape() as tape:
            probs = self.model(inputs, training=True)
            loss = self.loss(labels, probs)

        # ...and use them to update the model's weights.
        variables = self.model.trainable_variables
        gradients = tape.gradient(loss, variables)
        self.optimizer.apply_gradients(zip(gradients, variables))
        return loss
# ------------------------------
# Compile the Model with IREE
# ------------------------------
# tf.function methods of TrainableDNN to export from the compiled module.
exported_names = ["predict", "learn"]

# Notebook form-style parameter (`# @param`): only the first word of the
# label (e.g. "llvm-cpu") is the backend identifier. It is parsed once and
# reused for both compilation and loading so the two cannot drift apart.
backend_choice = "llvm-cpu (CPU)"  # @param [ "vmvx (CPU)", "llvm-cpu (CPU)", "vulkan-spirv (GPU/SwiftShader – requires additional drivers) " ]
backend_choice = backend_choice.split(" ")[0]

# Compile the TrainableDNN module for the selected target backend.
vm_flatbuffer = iree.compiler.tf.compile_module(
    TrainableDNN(), target_backends=[backend_choice], exported_names=exported_names
)

# Load the compiled flatbuffer with the same backend it was compiled for.
compiled_model = iree.runtime.load_vm_flatbuffer(vm_flatbuffer, backend=backend_choice)
- Run the command
python3 main.py
and see the error message.
What component(s) does this issue relate to?
No response
Version information
No response
Additional context
No response
What version of TensorFlow are you using?
My Python version is 3.11.8, the TensorFlow version is 2.17.0-dev20240228 (tf-nightly), and the IREE version is 20240228.815.