Training crashes with DebuggerHookConfig
I have a simple TensorFlow model that I'm training on SageMaker. It was working fine, but recently it started crashing during training, right after the first checkpoint is saved.
Following is the stack trace:
AssertionError: training/Ftrl/Ftrl/AssignAddVariableOp is not in graph
Traceback (most recent call last):
File "trainer_click_label.py", line 259, in <module>
estimator = train_and_evaluate(train_files, test_files, args.model_dir, args.epochs)
File "trainer_click_label.py", line 133, in train_and_evaluate
tf.estimator.train_and_evaluate(linear_est, train_spec=train_spec, eval_spec=eval_spec)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/training.py", line 473, in train_and_evaluate
return executor.run()
File "/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/training.py", line 613, in run
return self.run_local()
File "/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/training.py", line 714, in run_local
saving_listeners=saving_listeners)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/estimator.py", line 374, in train
loss = self._train_model(input_fn, hooks, saving_listeners)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/estimator.py", line 1164, in _train_model
return self._train_model_default(input_fn, hooks, saving_listeners)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/estimator.py", line 1198, in _train_model_default
saving_listeners)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/estimator.py", line 1497, in _train_with_estimator_spec
_, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/training/monitored_session.py", line 784, in run
run_metadata=run_metadata)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/training/monitored_session.py", line 1289, in run
run_metadata=run_metadata)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/training/monitored_session.py", line 1390, in run
raise six.reraise(*original_exc_info)
File "/usr/local/lib/python3.6/dist-packages/six.py", line 703, in reraise
raise value
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/training/monitored_session.py", line 1375, in run
return self._sess.run(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/training/monitored_session.py", line 1439, in run
feed_dict, options)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/training/monitored_session.py", line 1466, in _call_hook_before_run
request = hook.before_run(run_context)
File "/usr/local/lib/python3.6/dist-packages/smdebug/tensorflow/session.py", line 362, in before_run
run_context.original_args.fetches,
File "/usr/local/lib/python3.6/dist-packages/smdebug/tensorflow/session.py", line 343, in _filter_to_be_saved
tensor_ref.tf_obj, fetches_ops_tuple, unfilled_placeholders
File "/usr/local/lib/python3.6/dist-packages/smdebug/tensorflow/session.py", line 320, in _is_tensor_dependent_on_unfilled_placeholder
subgraph_nodes = self._get_subgraph_which_reach_fetches(fetches_ops_tuple)
File "/usr/local/lib/python3.6/dist-packages/smdebug/tensorflow/session.py", line 304, in _get_subgraph_which_reach_fetches
subgraph = tf.graph_util.extract_sub_graph(self.graph.as_graph_def(), dest_names)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/util/deprecation.py", line 324, in new_func
return func(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/graph_util_impl.py", line 197, in extract_sub_graph
_assert_nodes_are_present(name_to_node, dest_nodes)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/graph_util_impl.py", line 152, in _assert_nodes_are_present
assert d in name_to_node, "%s is not in graph" % d
Here is a snippet from my code:
import sagemaker
from sagemaker.debugger import DebuggerHookConfig
from sagemaker.tensorflow import TensorFlow

hook_config = DebuggerHookConfig(
    s3_output_path='s3://ratul-playground/sagemaker/debug_hook/search_linear',
    collection_configs=[],
)
estimator = TensorFlow(
    source_dir='code',
    entry_point=TRAINER_FILE,
    model_dir=model_dir,
    code_location=OUTPUT_CODE_S3_DIR,
    output_path=OUTPUT_TRAINING_S3_DIR,
    train_instance_type=train_instance_type,
    train_instance_count=1,
    hyperparameters=hyperparameters,
    sagemaker_session=sess,
    role=sagemaker.get_execution_role(),
    base_job_name='ranking-linear',
    framework_version='2.1',
    py_version='py3',
    script_mode=True,
    debugger_hook_config=hook_config,
    tensorboard_output_config=tb_config,
)
Within the trainer file:
initial_learning_rate = 0.2
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=10000, decay_rate=0.96, staircase=True
)
optimizer = tf.keras.optimizers.Ftrl(
    learning_rate=lr_schedule, l2_regularization_strength=0.001
)
linear_est = tf.estimator.LinearClassifier(
    feature_columns=feature_columns,
    n_classes=2,
    optimizer=optimizer,
    model_dir=os.path.join(model_dir, "checkpoints"),
    config=tf.estimator.RunConfig(save_checkpoints_secs=100),
)
train_input_fn = make_input_fn(train_files, num_epochs=num_epochs)
eval_input_fn = make_input_fn(test_files, num_epochs=1, shuffle=False)
# hook = smd.SessionHook.create_from_json_file()
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=20000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=1000, throttle_secs=0)
tf.estimator.train_and_evaluate(linear_est, train_spec=train_spec, eval_spec=eval_spec)
Right now the only workaround is to disable the debug hook by passing debugger_hook_config=False.
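For reference, this is the one change I make to get training running again (a sketch; every other estimator argument stays exactly as in the snippet above):

estimator = TensorFlow(
    # ... same arguments as above ...
    debugger_hook_config=False,  # disable the smdebug hook entirely
)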
Please advise.
@ratulray Thanks for reporting this. Would you be able to provide a minimal reproducible training script and sample data that we can use to reproduce this?
@ratulray I have provided a fix for this. It would be great if you could share a script, data, and repro steps so we can check that the fix works. Alternatively, you can take the fix and verify it yourself; I can help you with the required steps, which should be minimal. Please let me know whichever works for you.
Thanks, I'll try to create a minimal script that reproduces the error and upload it here soon.
I can reproduce the error with this script. Can you please test it with your fix?
import os
import argparse
import tensorflow as tf
import math
import numpy as np

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, default=os.environ.get("SM_MODEL_DIR"))
    return parser.parse_known_args()


if __name__ == "__main__":
    args, _ = parse_args()
    x = np.random.randint(0, 100, size=(50000, 5)).astype(np.float32)
    # randint's upper bound is exclusive, so use 2 to produce both classes
    y = np.random.randint(0, 2, size=50000).astype(np.float32)
    batch_size = 2
    train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
        x={"x": x}, y=y, batch_size=batch_size, num_epochs=3, shuffle=True
    )
    eval_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
        x={"x": np.random.randint(0, 100, size=(500, 5)).astype(np.float32)},
        y=np.random.randint(0, 2, size=500).astype(np.float32),
        batch_size=batch_size,
        num_epochs=1,
        shuffle=False,
    )
    initial_learning_rate = 0.2
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate, decay_steps=10000, decay_rate=0.96, staircase=True
    )
    optimizer = tf.keras.optimizers.Ftrl(
        learning_rate=lr_schedule, l2_regularization_strength=0.001
    )
    linear_est = tf.estimator.LinearClassifier(
        feature_columns=[tf.feature_column.numeric_column("x", shape=[5])],
        n_classes=2,
        optimizer=optimizer,
        model_dir=os.path.join(args.model_dir, "checkpoints"),
        config=tf.estimator.RunConfig(save_checkpoints_secs=20),
    )
    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=50000)
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
    tf.estimator.train_and_evaluate(linear_est, train_spec, eval_spec)
Thanks @ratulray. I think PR https://github.com/awslabs/sagemaker-debugger/pull/324 should fix this.
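If you'd like to verify the fix yourself before it is merged, one option (assuming your training environment lets you install packages, and relying on pip's support for installing straight from a pull-request ref) is to install the PR build of smdebug:

pip install git+https://github.com/awslabs/sagemaker-debugger.git@refs/pull/324/head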
If you want to test it against your code, you can modify your script like this:
import os
import argparse
import tensorflow as tf
import math
import numpy as np
#### Add this
import smdebug.tensorflow as smd

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, default="/tmp")  # os.environ.get("SM_MODEL_DIR") on SageMaker
    return parser.parse_known_args()


if __name__ == "__main__":
    args, _ = parse_args()
    x = np.random.randint(0, 100, size=(50000, 5)).astype(np.float32)
    # randint's upper bound is exclusive, so use 2 to produce both classes
    y = np.random.randint(0, 2, size=50000).astype(np.float32)
    batch_size = 2
    #### Add this and below
    save_config = smd.SaveConfig(save_interval=90)
    hook = smd.EstimatorHook(out_dir="/tmp/smdebug", save_config=save_config)
    # hook.get_collection("losses").include('head/losses/weighted_loss/value:0')
    ##### added until this point #####
    train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
        x={"x": x}, y=y, batch_size=batch_size, num_epochs=3, shuffle=True
    )
    eval_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
        x={"x": np.random.randint(0, 100, size=(500, 5)).astype(np.float32)},
        y=np.random.randint(0, 2, size=500).astype(np.float32),
        batch_size=batch_size,
        num_epochs=1,
        shuffle=False,
    )
    initial_learning_rate = 0.2
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate, decay_steps=10000, decay_rate=0.96, staircase=True
    )
    optimizer = tf.keras.optimizers.Ftrl(
        learning_rate=lr_schedule, l2_regularization_strength=0.001
    )
    linear_est = tf.estimator.LinearClassifier(
        feature_columns=[tf.feature_column.numeric_column("x", shape=[5])],
        n_classes=2,
        optimizer=optimizer,
        model_dir=os.path.join(args.model_dir, "checkpoints"),
        config=tf.estimator.RunConfig(save_checkpoints_secs=20),
    )
    ## Add hooks=[hook] in the two specs below.
    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=50000, hooks=[hook])
    # eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, hooks=[hook])
    tf.estimator.train_and_evaluate(linear_est, train_spec, eval_spec)
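Once the run gets past the first checkpoint save without the assertion, you can sanity-check what smdebug wrote with something like the snippet below (it reads the local /tmp/smdebug out_dir used by the EstimatorHook above):

from smdebug.trials import create_trial

# load the local trial directory written by the EstimatorHook
trial = create_trial("/tmp/smdebug")
print(trial.steps())         # steps at which tensors were saved
print(trial.tensor_names())  # names of the saved tensors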