tf.keras.models.load_model does not work as expected within MirroredStrategy
System information.
- Have I written custom code (as opposed to using a stock example script provided in Keras): Yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 18.04.5
- TensorFlow installed from (source or binary): binary
- TensorFlow version (use command below): 2.8.0
- Python version: 3.8
- Bazel version (if compiling from source):
- GPU model and memory: Tesla V100-SXM2-16GB x 4
- Exact command to reproduce: python3 load_model_within_mirrored_strategy.py
Describe the problem.
I tried to load a model within a MirroredStrategy scope and found that the loaded model does not work correctly: only one replica is found, even though 4 visible devices are actually specified. This does not happen for a model that is constructed directly within the MirroredStrategy scope.
It is worth mentioning that subclassed tf.keras.models.Model and tf.keras.layers.Layer are used here, which I think may be the cause of this wrong behavior. I have confirmed that loading a saved tf.keras.Sequential model works well within MirroredStrategy.
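For reference, the Sequential control was along these lines (a minimal sketch; the exact layers are arbitrary):

# Control: a saved Sequential model loads and runs on all replicas as expected.
import tensorflow as tf

seq = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(10,))])
seq.save("seq_demo")

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
with strategy.scope():
    loaded = tf.keras.models.load_model("seq_demo")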
Describe the current behavior.
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
with strategy.scope():
    model = tf.keras.models.load_model("demo")

A model loaded this way runs on only one replica.
Describe the expected behavior.
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
with strategy.scope():
    model = tf.keras.models.load_model("demo")

A model loaded this way should run on all 4 replicas.
- Do you want to contribute a PR? (yes/no): no
Standalone code to reproduce the issue.
load_model_within_mirrored_strategy.py
import tensorflow as tf
import tensorflow.distribute as tf_dist
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.ops import array_ops
import numpy as np
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"


def _get_current_replica_id_in_group_sync():
    replica_ctx = tf_dist.get_replica_context()
    if replica_ctx:
        replica_id = replica_ctx.replica_id_in_sync_group
    else:
        replica_id = distribute_lib.get_update_replica_id()
    if replica_id is None:
        replica_id = array_ops.constant(0, dtype=array_ops.dtypes.int32)
    return replica_id


class TestLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(TestLayer, self).__init__(**kwargs)

    def call(self, inputs, training=False):
        global_replica_id = _get_current_replica_id_in_group_sync()
        tf.print("global_replica_id: {}".format(global_replica_id))
        emb_vector = tf.zeros_like(inputs)
        return emb_vector


class Demo(tf.keras.models.Model):
    def __init__(self, **kwargs):
        super(Demo, self).__init__(**kwargs)
        self.test_layer = TestLayer()
        self.dense_layer = tf.keras.layers.Dense(units=1, activation=None,
                                                 kernel_initializer="ones",
                                                 bias_initializer="zeros")

    def call(self, inputs):
        vector = self.test_layer(inputs)
        logit = self.dense_layer(vector)
        return logit, vector

    def summary(self):
        inputs = tf.keras.Input(shape=(10,), dtype=tf.int64)
        model = tf.keras.models.Model(inputs=inputs, outputs=self.call(inputs))
        return model.summary()


@tf.function
def _step(inputs, labels, model):
    logit, vector = model(inputs)
    return logit, vector


def tf_dataset(keys, labels, batchsize, repeat):
    dataset = tf.data.Dataset.from_tensor_slices((keys, labels))
    dataset = dataset.repeat(repeat)
    dataset = dataset.batch(batchsize, drop_remainder=True)
    return dataset


def _dataset_fn(input_context):
    global_batch_size = 16384
    keys = np.ones((global_batch_size, 10))
    labels = np.random.randint(low=0, high=2, size=(global_batch_size, 1))
    replica_batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    dataset = tf_dataset(keys, labels, batchsize=replica_batch_size, repeat=1)
    dataset = dataset.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)
    return dataset


# Save model within MirroredStrategy scope
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
with strategy.scope():
    model = Demo()
    model.compile()
    model.summary()

dataset = strategy.distribute_datasets_from_function(_dataset_fn)
for i, (key_tensors, replica_labels) in enumerate(dataset):
    print("-" * 30, "step ", str(i), "-" * 30)
    logit, vector = strategy.run(_step, args=(key_tensors, replica_labels, model))

# model(tf.keras.Input(shape=(10,), dtype=tf.int64))
model.save("demo")

# Load model within MirroredStrategy scope
with strategy.scope():
    model2 = tf.keras.models.load_model("demo")

dataset = strategy.distribute_datasets_from_function(_dataset_fn)
for i, (key_tensors, replica_labels) in enumerate(dataset):
    print("-" * 30, "step ", str(i), "-" * 30)
    logit, vector = strategy.run(_step, args=(key_tensors, replica_labels, model2))
Source code / logs.
Actual log
------------------------------ step 0 ------------------------------
global_replica_id: Tensor("demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:0)
global_replica_id: Tensor("replica_1/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:1)
global_replica_id: Tensor("replica_2/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:2)
global_replica_id: Tensor("replica_3/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:3)
2022-07-13 06:20:56.820402: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
------------------------------ step 0 ------------------------------
global_replica_id: 0
Expected log
------------------------------ step 0 ------------------------------
global_replica_id: Tensor("demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:0)
global_replica_id: Tensor("replica_1/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:1)
global_replica_id: Tensor("replica_2/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:2)
global_replica_id: Tensor("replica_3/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:3)
2022-07-13 06:20:56.820402: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
------------------------------ step 0 ------------------------------
global_replica_id: Tensor("demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:0)
global_replica_id: Tensor("replica_1/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:1)
global_replica_id: Tensor("replica_2/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:2)
global_replica_id: Tensor("replica_3/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:3)
The log output comes from the tf.print("global_replica_id: {}".format(global_replica_id)) line within TestLayer.call.
@tilakrayal Any progress here?
@gadagashwini Any progress here?
@yuefengz do you or anyone on tf.distribute have some clues of why loading the model under mirrored strategy gives a different result?
@rchao Any progress here?
We're waiting on input from the tf.distribute team.
Hi, I'm on the tf.distribute team. I'll start looking into this issue, and reply once I have insights into it.
UPDATE 1:
It seems like the user-specified strategy and its replica context are (indirectly?) cached while executing or saving the model, and are absent otherwise. This seems to be occurring because, when the test exhibits unexpected behavior, the types of its strategy and replica context are incorrect:

| | tf_dist.get_strategy() | tf_dist.get_replica_context() |
| --- | --- | --- |
| Expected | MirroredStrategy | _MirroredReplicaContext |
| Actual | _DefaultDistributionStrategy | _DefaultReplicaContext |

The incorrect types are the defaults used when a compiled Graph's strategy stack is empty; see _get_per_thread_mode. If one compares the graph's strategy stack between (in)correct runs, it becomes evident that the strategy stack isn't being passed from _step to TestLayer.call:
Correct run:
_step: graph._distribution_strategy_stack=[<tensorflow.python.distribute.distribution_strategy_context._InReplicaThreadMode object at 0x7f8d18f13520>]
TestLayer.call: graph._distribution_strategy_stack=[<tensorflow.python.distribute.distribution_strategy_context._InReplicaThreadMode object at 0x7f8d18f13520>]
Incorrect run:
_step: graph._distribution_strategy_stack=[<tensorflow.python.distribute.distribution_strategy_context._InReplicaThreadMode object at 0x7fd7140598b0>]
TestLayer.call: graph._distribution_strategy_stack=[]
Next, I plan to examine the __call__ methods of keras.models.Model and keras.layers.Layer. Another relevant function is FuncGraph.as_default, which normally copies the strategy stack to a nested (inner) graph from an outer one.
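For anyone who wants to reproduce these observations, a rough probe layer works (my own sketch; _distribution_strategy_stack is a private attribute, used here only because the logs above already reference it):

import tensorflow as tf
import tensorflow.distribute as tf_dist
from tensorflow.python.framework import ops

class ProbeLayer(tf.keras.layers.Layer):
    def call(self, inputs):
        # Python-level prints run at trace time, which is when the strategy
        # stack is consulted; a correct run reports MirroredStrategy here.
        print("strategy:", type(tf_dist.get_strategy()).__name__)
        print("replica ctx:", type(tf_dist.get_replica_context()).__name__)
        print("stack:", ops.get_default_graph()._distribution_strategy_stack)
        return tf.zeros_like(inputs)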
UPDATE 2:
It seems like another key difference between (in)correct runs is that, after loading, the function TestLayer.call has its own Graph, whereas it is "inlined" into the graph of _step under normal circumstances. See this gist for more information:
https://gist.github.com/jszaday/a9c4484cc391ce98954e279434826438
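One way to see this from user code (a sketch of mine, separate from the gist; the shape and dtype follow the reproduction's per-replica batches):

import tensorflow as tf

@tf.function
def probe(inputs, model):
    return model(inputs)

# With the original model, TestLayer's ops (ZerosLike, etc.) appear directly
# in probe's graph; with the loaded model they sit behind a call op that owns
# its own FuncGraph.
spec = tf.TensorSpec([4096, 10], tf.float64)
for name, m in (("original", model), ("loaded", model2)):
    graph = probe.get_concrete_function(spec, m).graph
    print(name, sorted({op.type for op in graph.get_operations()}))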
@jszaday Thanks for the updates. Please let me know when there is a solution or workaround for this issue.
After looking into this extensively, I realized that during the model-saving process, a tf.function is run with the default replica context and strategy, and the output from its tf.print statements is hard-coded within the saved_model.pb file. Then, whenever the model is loaded, the tf.function may not be run; instead, the output from the file is printed.
More about that here.
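Note that the reproduction's tf.print("global_replica_id: {}".format(...)) makes this especially visible: the .format() call runs at trace time, so the op only ever prints a fixed string. A minimal standalone sketch of the distinction (my example, not from the issue):

import tensorflow as tf

@tf.function
def f(x):
    # str.format() runs at trace time, so tf.print receives a fixed Python
    # string containing the symbolic tensor's repr; the value never updates.
    tf.print("traced: {}".format(x))
    # Passing the tensor as a separate argument prints its runtime value.
    tf.print("runtime:", x)
    return x

f(tf.constant(1))
f(tf.constant(2))  # "traced:" is identical both times; "runtime:" differs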
I spoke to @k-w-w and she reported that, indeed, tf.print is a strange op: it's not compatible with SavedModel because, when one saves it to a GraphDef, it doesn't actually save the Python reference. Instead, it hardcodes the value, like we're seeing here.
I'm not yet certain whether there's a workaround; it doesn't seem like using (Python) print guarded by not save_context.in_save_context() works, and using tf.print isn't recommended outside of debugging. One alternative might be to return this information as a tensor, if it's necessary (see the sketch below).
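A rough sketch of that alternative, adapting the reproduction's TestLayer (Demo.call would need to forward the extra output):

class TestLayer(tf.keras.layers.Layer):
    def call(self, inputs, training=False):
        # Return the replica id as a graph output instead of printing it; a
        # real tensor survives save/load, unlike a trace-time string.
        replica_id = _get_current_replica_id_in_group_sync()
        emb_vector = tf.zeros_like(inputs)
        return emb_vector, replica_id

After strategy.run, the result would then be a PerReplica value holding one id per replica.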
So the loaded model graph can still be executed correctly under MirroredStrategy, and the real problem is tf.print?
Is there any way to verify it is working correctly after loading the model graph?
@jszaday I checked the output tensor of the loaded model graph (print(vector) after logit, vector = strategy.run(...)), and the log looks like this:
------------------------------ step 0 ------------------------------
global_replica_id: Tensor("demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:GPU:0)
global_replica_id: Tensor("replica_1/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:GPU:1)
global_replica_id: Tensor("replica_2/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:GPU:2)
global_replica_id: Tensor("replica_3/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:GPU:3)
2022-07-27 00:06:22.867445: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
PerReplica:{
0: tf.Tensor(
[[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]], shape=(4096, 10), dtype=float32),
1: tf.Tensor(
[[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]], shape=(4096, 10), dtype=float32),
2: tf.Tensor(
[[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]], shape=(4096, 10), dtype=float32),
3: tf.Tensor(
[[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]], shape=(4096, 10), dtype=float32)
}
------------------------------ step 0 ------------------------------
global_replica_id: 0
PerReplica:{
0: tf.Tensor(
[[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]], shape=(4096, 10), dtype=float32),
1: tf.Tensor(
[[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]], shape=(4096, 10), dtype=float32),
2: tf.Tensor(
[[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]], shape=(4096, 10), dtype=float32),
3: tf.Tensor(
[[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]], shape=(4096, 10), dtype=float32)
}
Does this mean that the loaded model graph is working correctly under MirroredStrategy?
Yeah, the real problem is the tf.print as far as I can tell.
And, as for your logs, those look favorable to me; others can weigh in if there are noteworthy alternative validation mechanisms (e.g., the TensorFlow Profiler?).
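One concrete check is to inspect the device placement of the per-replica outputs (a sketch against the reproduction script, using the public API; I haven't run it here):

# Run one step with the loaded model, then inspect where each local result
# lives; with 4 replicas we expect GPU:0 through GPU:3.
logit, vector = strategy.run(_step, args=(key_tensors, replica_labels, model2))
for i, v in enumerate(strategy.experimental_local_results(vector)):
    print("replica", i, "->", v.device)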