
tf.keras.models.load_model does not work as expected within MirroredStrategy

Open: KingsleyLiu-NV opened this issue 2 years ago (11 comments)

System information.

  • Have I written custom code (as opposed to using a stock example script provided in Keras): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 18.04.5
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 2.8.0
  • Python version: 3.8
  • Bazel version (if compiling from source):
  • GPU model and memory: Tesla V100-SXM2-16GB x 4
  • Exact command to reproduce: python3 load_model_within_mirrored_strategy.py

Describe the problem.

I am trying to load a model within MirroredStrategy. The model loaded within MirroredStrategy does not behave correctly: only one replica is found, even though 4 visible devices are specified. This does not happen for a model that is constructed directly within MirroredStrategy.

It is worth mentioning that subclassed tf.keras.models.Model and tf.keras.layers.Layer are used here, which I think may be the cause of this behavior. I have confirmed that loading a saved tf.keras.Sequential model works well within MirroredStrategy.

Describe the current behavior.

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
with strategy.scope():
    model = tf.keras.models.load_model("demo")

A model loaded this way runs on only one replica.

Describe the expected behavior.

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
with strategy.scope():
    model = tf.keras.models.load_model("demo")

A model loaded this way should run on all 4 replicas.

Contributing.

  • Do you want to contribute a PR? (yes/no): no

Standalone code to reproduce the issue.

load_model_within_mirrored_strategy.py

import tensorflow as tf
import tensorflow.distribute as tf_dist
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.ops import array_ops
import numpy as np
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

def _get_current_replica_id_in_group_sync():
    replica_ctx = tf_dist.get_replica_context()
    if replica_ctx:
        replica_id = replica_ctx.replica_id_in_sync_group
    else:
        replica_id = distribute_lib.get_update_replica_id()
    if replica_id is None:
        replica_id = array_ops.constant(0, dtype=array_ops.dtypes.int32)
    return replica_id

class TestLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(TestLayer, self).__init__(**kwargs)

    def call(self, inputs, training=False):
        global_replica_id = _get_current_replica_id_in_group_sync()
        # Note: str.format() is evaluated by Python while the function is traced,
        # so the printed message is fixed at trace time, not computed at run time.
        tf.print("global_replica_id: {}".format(global_replica_id))
        emb_vector = tf.zeros_like(inputs)
        return emb_vector

class Demo(tf.keras.models.Model):
    def __init__(self, **kwargs):
        super(Demo, self).__init__(**kwargs)
        
        self.test_layer = TestLayer()        
        self.dense_layer = tf.keras.layers.Dense(units=1, activation=None,
                                                 kernel_initializer="ones",
                                                 bias_initializer="zeros")

    def call(self, inputs):
        vector = self.test_layer(inputs)
        logit = self.dense_layer(vector)
        return logit, vector

    def summary(self):
        inputs = tf.keras.Input(shape=(10,), dtype=tf.int64)
        model = tf.keras.models.Model(inputs=inputs, outputs=self.call(inputs))
        return model.summary()

@tf.function
def _step(inputs, labels, model):
    logit, vector = model(inputs)
    return logit, vector

def tf_dataset(keys, labels, batchsize, repeat):
    dataset = tf.data.Dataset.from_tensor_slices((keys, labels))
    dataset = dataset.repeat(repeat)
    dataset = dataset.batch(batchsize, drop_remainder=True)
    return dataset

def _dataset_fn(input_context):
    global_batch_size = 16384
    keys = np.ones((global_batch_size, 10))
    labels = np.random.randint(low=0, high=2, size=(global_batch_size, 1))
    replica_batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    dataset = tf_dataset(keys, labels, batchsize=replica_batch_size, repeat=1)
    dataset = dataset.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)
    return dataset

# Save model within MirroredStrategy scope
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"])
with strategy.scope():
    model = Demo()
model.compile()
model.summary()
dataset = strategy.distribute_datasets_from_function(_dataset_fn)
for i, (key_tensors, replica_labels) in enumerate(dataset):
    print("-" * 30, "step ", str(i), "-" * 30)
    logit, vector = strategy.run(_step, args=(key_tensors, replica_labels, model))
# model(tf.keras.Input(shape=(10,), dtype=tf.int64))
model.save("demo")

# Load model within MirroredStrategy scope
with strategy.scope():
    model2 = tf.keras.models.load_model("demo")
dataset = strategy.distribute_datasets_from_function(_dataset_fn)
for i, (key_tensors, replica_labels) in enumerate(dataset):
    print("-" * 30, "step ", str(i), "-" * 30)
    logit, vector = strategy.run(_step, args=(key_tensors, replica_labels, model2))

Source code / logs.

Actual log

------------------------------ step  0 ------------------------------
global_replica_id: Tensor("demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:0)
global_replica_id: Tensor("replica_1/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:1)
global_replica_id: Tensor("replica_2/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:2)
global_replica_id: Tensor("replica_3/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:3)
2022-07-13 06:20:56.820402: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
------------------------------ step  0 ------------------------------
global_replica_id: 0

Expected log

------------------------------ step  0 ------------------------------
global_replica_id: Tensor("demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:0)
global_replica_id: Tensor("replica_1/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:1)
global_replica_id: Tensor("replica_2/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:2)
global_replica_id: Tensor("replica_3/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:3)
2022-07-13 06:20:56.820402: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
------------------------------ step  0 ------------------------------
global_replica_id: Tensor("demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:0)
global_replica_id: Tensor("replica_1/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:1)
global_replica_id: Tensor("replica_2/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:2)
global_replica_id: Tensor("replica_3/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:gpu:3)

The log is from the line tf.print("global_replica_id: {}".format(global_replica_id)) within TestLayer.call.
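The first run's log lines contain the literal text "Tensor(...)" rather than numbers, which hints at when the message is built: "global_replica_id: {}".format(global_replica_id) is evaluated by Python while the tf.function is being traced, when only a symbolic tensor exists. The minimal sketch below (the two function names are hypothetical, not from the thread) contrasts that with tf.strings.format, a graph op that substitutes the value every time the traced function runs.

```python
import tensorflow as tf


@tf.function
def trace_time_message(x):
    # str.format() runs once, while Python traces this function. At that point
    # x is a symbolic tensor, so its repr is frozen into a constant string.
    return tf.constant("global_replica_id: {}".format(x))


@tf.function
def run_time_message(x):
    # tf.strings.format() is a graph op: the value of x is substituted each
    # time the traced function executes.
    return tf.strings.format("global_replica_id: {}", x)


for value in (1, 2):
    x = tf.constant(value)
    # The first message is identical for both calls (no retrace for the same
    # dtype and shape); the second one reflects the actual value.
    print(trace_time_message(x).numpy(), run_time_message(x).numpy())
```

In the reproduction, the runtime counterpart would be tf.print("global_replica_id:", global_replica_id), which hands the tensor itself to tf.print. A message frozen at trace time is also consistent with the loaded model printing one fixed line, global_replica_id: 0, although the discussion below attributes that to how tf.print interacts with SavedModel.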

KingsleyLiu-NV, Jul 13 '22 09:07

@tilakrayal Any progress here?

KingsleyLiu-NV, Jul 14 '22 04:07

@gadagashwini Any progress here?

KingsleyLiu-NV, Jul 14 '22 09:07

@yuefengz do you or anyone on tf.distribute have any clues as to why loading the model under MirroredStrategy gives a different result?

rchao, Jul 14 '22 17:07

@rchao Any progress here?

KingsleyLiu-NV, Jul 21 '22 03:07

We're waiting on input from the tf.distribute team.

rchao, Jul 22 '22 16:07

Hi, I'm on the tf.distribute team. I'll start looking into this issue, and reply once I have insights into it.

UPDATE 1: It seems like the user-specified strategy and its replica context are (indirectly?) cached while executing or saving the model, and are absent otherwise. The evidence is that, when the test exhibits the unexpected behavior, its strategy and replica context have the wrong types:

Run       tf_dist.get_strategy()          tf_dist.get_replica_context()
Expected  MirroredStrategy                _MirroredReplicaContext
Actual    _DefaultDistributionStrategy    _DefaultReplicaContext

These incorrect types are the defaults returned when a graph's strategy stack is empty (see _get_per_thread_mode). Comparing the graph's strategy stack between the correct and incorrect runs shows that the stack isn't being passed from _step to TestLayer.call:

Correct run:
_step: graph._distribution_strategy_stack=[<tensorflow.python.distribute.distribution_strategy_context._InReplicaThreadMode object at 0x7f8d18f13520>]
TestLayer.call: graph._distribution_strategy_stack=[<tensorflow.python.distribute.distribution_strategy_context._InReplicaThreadMode object at 0x7f8d18f13520>]

Incorrect run:
_step: graph._distribution_strategy_stack=[<tensorflow.python.distribute.distribution_strategy_context._InReplicaThreadMode object at 0x7fd7140598b0>]
TestLayer.call: graph._distribution_strategy_stack=[]
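The default-versus-mirrored difference can be probed in isolation with public tf.distribute functions instead of the private _get_per_thread_mode. The sketch below is not from the thread: a hypothetical probe function records, at trace time, whether a non-default strategy is active and how many replicas its replica context reports, once inside strategy.run and once at top level. It assumes no GPUs are needed: the host CPU is split into four logical devices so MirroredStrategy gets four replicas, which only works in a fresh Python process before TensorFlow initializes its runtime (otherwise it falls back to the existing CPU devices).

```python
import tensorflow as tf

# Assumption: four logical CPU devices stand in for the four GPUs in the issue.
cpus = tf.config.list_physical_devices("CPU")
try:
    tf.config.set_logical_device_configuration(
        cpus[0], [tf.config.LogicalDeviceConfiguration()] * 4)
except RuntimeError:
    pass  # runtime already initialized: keep the existing logical devices
devices = [d.name for d in tf.config.list_logical_devices("CPU")]
strategy = tf.distribute.MirroredStrategy(devices)


def probe():
    # Python values read while tracing: they describe the strategy and replica
    # context that are active when this function is traced.
    ctx = tf.distribute.get_replica_context()
    return (tf.constant(tf.distribute.has_strategy()),
            tf.constant(ctx.num_replicas_in_sync))


@tf.function
def probe_in_replicas():
    return strategy.run(probe)


flags, counts = probe_in_replicas()
in_replica_flags = [bool(f) for f in strategy.experimental_local_results(flags)]
in_replica_counts = [int(c) for c in strategy.experimental_local_results(counts)]

# At top level, outside strategy.scope() and strategy.run(), the default
# strategy and the default replica context are in effect.
outside = (tf.distribute.has_strategy(),
           tf.distribute.get_replica_context().num_replicas_in_sync)

print("inside strategy.run:", in_replica_flags, in_replica_counts)
print("outside:", outside)
```

In the incorrect run above, TestLayer.call was traced with an empty strategy stack, i.e. in the situation the `outside` values illustrate, even though it was reached from _step inside strategy.run.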

Next, I plan to examine the __call__ methods of keras.models.Model and keras.layers.Layer. Another relevant function is FuncGraph.as_default, which normally copies the strategy stack from an outer graph to a nested (inner) one.

UPDATE 2: It seems like another key difference between (in)correct runs is that, after loading, the function TestLayer.call has its own Graph, whereas it is "inlined" into the graph of _step under normal circumstances. See this gist for more information: https://gist.github.com/jszaday/a9c4484cc391ce98954e279434826438

jszaday, Jul 22 '22 18:07

@jszaday Thanks for the updates. Please let me know when there is a solution or workaround for this issue.

KingsleyLiu-NV, Jul 26 '22 23:07

After looking into this extensively, I realized that during the model-saving process, a tf.function is run with the default replica context and strategy, and the output of its tf.print statements is hard-coded in the saved_model.pb file. Then, whenever the model is loaded, the tf.function may not be re-run; instead, the output stored in the file is printed. More about that here.

I spoke to @k-w-w, and she reported that tf.print is indeed a strange op: it's not compatible with SavedModel because, when it is saved to a GraphDef, the Python reference isn't actually saved. Instead, the value is hard-coded, as we're seeing here.

It's still uncertain whether there's a workaround for this. Using a (Python) print guarded by not save_context.in_save_context() doesn't seem to work, and tf.print isn't recommended outside of debugging. One alternative, if this information is needed, might be to return it as a tensor.
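A minimal sketch of returning the replica id as a tensor, assuming no GPUs: four logical CPU devices stand in for the four GPUs (fresh Python process required; otherwise it falls back to the existing CPU devices). ReplicaIdLayer is a hypothetical stand-in for TestLayer, not code from the thread; it reads the id through the public tf.distribute.get_replica_context() rather than the private distribute_lib helper used in the reproduction, and strategy.experimental_local_results exposes one value per replica. The sketch is not saved or reloaded, so it does not show whether the same holds after a SavedModel round trip.

```python
import tensorflow as tf

# Assumption: four logical CPU devices stand in for the four GPUs in the issue.
cpus = tf.config.list_physical_devices("CPU")
try:
    tf.config.set_logical_device_configuration(
        cpus[0], [tf.config.LogicalDeviceConfiguration()] * 4)
except RuntimeError:
    pass  # runtime already initialized: keep the existing logical devices
strategy = tf.distribute.MirroredStrategy(
    [d.name for d in tf.config.list_logical_devices("CPU")])


class ReplicaIdLayer(tf.keras.layers.Layer):
    """Hypothetical stand-in for TestLayer that also returns the replica id."""

    def call(self, inputs):
        replica_id = tf.distribute.get_replica_context().replica_id_in_sync_group
        # Debugging only: passing the tensor as its own argument makes tf.print
        # show the runtime value, unlike str.format() at trace time.
        tf.print("global_replica_id:", replica_id)
        return tf.zeros_like(inputs), replica_id


with strategy.scope():
    layer = ReplicaIdLayer()


def replica_fn(inputs):
    return layer(inputs)


@tf.function
def step(inputs):
    return strategy.run(replica_fn, args=(inputs,))


vectors, replica_ids = step(tf.ones((2, 10)))
ids = [int(r) for r in strategy.experimental_local_results(replica_ids)]
print(ids)
```

With four replicas, ids should list one distinct id per replica, 0 through 3, which is a check that does not depend on tf.print output at all.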

jszaday, Jul 26 '22 23:07

So the loaded model graph still executes correctly under MirroredStrategy, and the real problem is tf.print? Is there any way to verify that it is working correctly after loading the model graph?

KingsleyLiu-NV, Jul 26 '22 23:07

@jszaday I checked the output tensor of the loaded model graph (print(vector) after logit, vector = strategy.run(...)), and the log looks like this:

------------------------------ step  0 ------------------------------
global_replica_id: Tensor("demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:GPU:0)
global_replica_id: Tensor("replica_1/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:GPU:1)
global_replica_id: Tensor("replica_2/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:GPU:2)
global_replica_id: Tensor("replica_3/demo/test_layer/replica_id_in_sync_group:0", shape=(), dtype=int32, device=/job:localhost/replica:0/task:0/device:GPU:3)
2022-07-27 00:06:22.867445: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
PerReplica:{
  0: tf.Tensor(
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]], shape=(4096, 10), dtype=float32),
  1: tf.Tensor(
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]], shape=(4096, 10), dtype=float32),
  2: tf.Tensor(
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]], shape=(4096, 10), dtype=float32),
  3: tf.Tensor(
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]], shape=(4096, 10), dtype=float32)
}
------------------------------ step  0 ------------------------------
global_replica_id: 0
PerReplica:{
  0: tf.Tensor(
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]], shape=(4096, 10), dtype=float32),
  1: tf.Tensor(
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]], shape=(4096, 10), dtype=float32),
  2: tf.Tensor(
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]], shape=(4096, 10), dtype=float32),
  3: tf.Tensor(
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]], shape=(4096, 10), dtype=float32)
}

Does this mean that the loaded model graph is working correctly under MirroredStrategy?

KingsleyLiu-NV, Jul 27 '22 00:07

Yeah, the real problem is the tf.print as far as I can tell.

And, as for your logs, those look favorable to me; others can weigh in if there are noteworthy alternative validation mechanisms (e.g., TensorFlow Profiler?).
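One more check that complements printing the PerReplica values: confirm that the number of per-replica results equals strategy.num_replicas_in_sync and that each replica received global_batch / num_replicas rows, as in the (4096, 10) shapes from a global batch of 16384 on four GPUs. The sketch below is an illustration, not code from the thread: it scales the issue's _dataset_fn down to a global batch of 16, replaces the model with a single Dense layer built inside strategy.scope(), and assumes four logical CPU devices in place of GPUs (fresh Python process required; otherwise it falls back to the existing CPU devices).

```python
import numpy as np
import tensorflow as tf

# Assumption: four logical CPU devices stand in for the four GPUs in the issue.
cpus = tf.config.list_physical_devices("CPU")
try:
    tf.config.set_logical_device_configuration(
        cpus[0], [tf.config.LogicalDeviceConfiguration()] * 4)
except RuntimeError:
    pass  # runtime already initialized: keep the existing logical devices
strategy = tf.distribute.MirroredStrategy(
    [d.name for d in tf.config.list_logical_devices("CPU")])

GLOBAL_BATCH = 16  # scaled down from 16384 in the issue


def dataset_fn(input_context):
    # Same structure as _dataset_fn in the issue: batch per replica, then shard.
    per_replica_batch = input_context.get_per_replica_batch_size(GLOBAL_BATCH)
    keys = np.ones((GLOBAL_BATCH, 10), dtype=np.float32)
    dataset = tf.data.Dataset.from_tensor_slices(keys)
    dataset = dataset.batch(per_replica_batch, drop_remainder=True)
    return dataset.shard(input_context.num_input_pipelines,
                         input_context.input_pipeline_id)


with strategy.scope():
    dense = tf.keras.layers.Dense(1, kernel_initializer="ones",
                                  bias_initializer="zeros")
    dense.build((None, 10))  # create the mirrored variables up front


def replica_fn(batch):
    return dense(batch)


@tf.function
def step(batch):
    return strategy.run(replica_fn, args=(batch,))


per_replica_shapes = []
per_replica_values = []
for batch in strategy.distribute_datasets_from_function(dataset_fn):
    logits = strategy.experimental_local_results(step(batch))
    per_replica_shapes.append([tuple(t.shape) for t in logits])
    # Each input row is ten ones and the kernel is all ones, so every logit is 10.
    per_replica_values.extend(t.numpy() for t in logits)

print(strategy.num_replicas_in_sync, per_replica_shapes)
```

The same two checks (one result per replica, each with the per-replica batch size) can be applied to the logit and vector outputs of the loaded model in the reproduction.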

jszaday, Jul 27 '22 03:07