
LAMB optimizer fails in MultiWorkerMirroredStrategy

Open pstjohn opened this issue 5 years ago • 6 comments

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): RHEL 7.6
  • TensorFlow version and how it was installed (source or binary): TF 2.1, (IBM WML CE 1.7.3)
  • TensorFlow-Addons version and how it was installed (source or binary): 0.9.1 source install
  • Python version: 3.6.10
  • Is GPU used? (yes/no): yes

Describe the bug

Using LAMB in a MultiWorkerMirroredStrategy fails with the ambiguous error message

2020-05-27 17:07:57.478805: F tensorflow/core/framework/tensor_shape.cc:345] Check failed: size >= 0 (-8648 vs. 0)
2020-05-27 17:08:01.205512: F tensorflow/core/framework/tensor_shape.cc:345] Check failed: size >= 0 (-8651 vs. 0)

I don't have a great concise code sample -- if there were a minimal MultiWorkerMirroredStrategy example somewhere, I'd be happy to try it out.

But the same code works (1) in a single-node, 6-GPU MirroredStrategy distribution using LAMB, and (2) in a two-node, 12-GPU MultiWorkerMirroredStrategy using tfa.optimizers.AdamW.

Other info / logs

This is on a ppc64le system, run via an LSF queue.

pstjohn commented May 27 '20 21:05

Hi @pstjohn, thanks for the report. It would be great if you could provide a minimal code snippet about MultiWorkerMirroredStrategy so that we can diagnose where the problem is!

WindQAQ commented May 28 '20 18:05

Will do -- I'm having some trouble reproducing this with a minimal example, so I'll close for now. If I can get a clean MultiWorkerMirroredStrategy snippet that reproduces it, I'll reopen.

pstjohn commented May 31 '20 16:05

OK @WindQAQ, I think I've got a minimal example. I suspect this is an interaction between mixed precision and the LAMB optimizer in a multi-worker setting, since I haven't been able to reproduce the error on a single node with mixed precision, or on multiple nodes without mixed precision.

This code works with the tfa AdamW optimizer (and other built-in tf optimizers) but fails with LAMB:

import os
import re
import argparse
import shutil
import subprocess
import json

# Initialize the TF_CONFIG environment variable based on process rank
# (from https://code.ornl.gov/olcf-analytics/summit/distributed-deep-learning-examples/tree/master/examples/tensorflow)
get_cnodes = "echo $(cat {} | sort | uniq | grep -v batch | grep -v login)".format(os.environ['LSB_DJOB_HOSTFILE'])
cnodes = subprocess.check_output(get_cnodes, shell=True)
cnodes = str(cnodes)[2:-3].split(' ')
nodes_list = [c + ":2222" for c in cnodes] # Add a port number

# Set the TF_CONFIG environment variable to configure the cluster setting.
index = int(os.environ['PMIX_RANK'])
tf_config = json.dumps({
    'cluster': {
        'worker': nodes_list
    },
    'task': {'type': 'worker', 'index': index} 
})
os.environ['TF_CONFIG'] = tf_config
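# For reference, the resulting TF_CONFIG for two hypothetical workers looks like:
# {"cluster": {"worker": ["node1:2222", "node2:2222"]},
#  "task": {"type": "worker", "index": 0}}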

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import tensorflow_addons.optimizers as tfa_optimizers

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
    communication=tf.distribute.experimental.CollectiveCommunication.NCCL)

from tensorflow.keras.mixed_precision import experimental as mixed_precision
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_policy(policy)
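# Note: with the experimental API in TF 2.1, 'mixed_float16' also enables dynamic
# loss scaling, so compile() below wraps the optimizer in a LossScaleOptimizer;
# variables stay float32 while activations are computed in float16.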

# optimizer = tfa_optimizers.AdamW(learning_rate=1E-3, weight_decay=1E-3)  # WORKS
optimizer = tfa_optimizers.LAMB(learning_rate=1E-3)  # FAILS

input_arrays = tf.data.Dataset.from_generator(
    lambda: np.random.randint(50, size=(1, 32)),
    output_types=(tf.int32), output_shapes=(32,))

targets = tf.data.Dataset.from_generator(
    lambda: np.random.randint(50, size=(1, 32)),
    output_types=(tf.int32), output_shapes=(32,))

dataset = tf.data.Dataset.zip((input_arrays, targets)).repeat().batch(64*12)
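# AutoShardPolicy.DATA shards the input by element index rather than by files,
# which is needed here because the dataset comes from a generator, not files.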
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
dataset_no_auto_shard = dataset.with_options(options)


with strategy.scope():
    inputs = layers.Input(shape=(None,), dtype=tf.int32, batch_size=None)
    embeddings = layers.Embedding(50, 32)(inputs)
    outputs = layers.Dense(50)(embeddings)
    model = tf.keras.Model(inputs, outputs, name='model')
    
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    optimizer=optimizer)

model.fit(dataset_no_auto_shard, steps_per_epoch=10, epochs=3, verbose=1)

This code works with the AdamW optimizer from tensorflow_addons, but fails with LAMB. The specific error for this code is

2020-05-31 12:15:05.426691: F tensorflow/core/framework/tensor_shape.cc:345] Check failed: size >= 0 (-79 vs. 0)

This is being run on OLCF's Summit supercomputer, on 2 nodes with 6 GPUs each.
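For what it's worth, the main place LAMB differs from AdamW is the per-variable "trust ratio" that rescales each update by the ratio of the weight norm to the update norm, so my guess is the failure is somewhere around those norm computations under mixed precision. A rough, simplified sketch of that step (not the actual tensorflow_addons implementation):

import tensorflow as tf

# Toy stand-ins for one model variable and its Adam-style update direction.
var = tf.Variable(tf.random.normal([32, 50]))
update = tf.random.normal([32, 50])
lr = 1e-3

# LAMB layer adaptation (simplified): scale the update by ||w|| / ||u||,
# falling back to 1.0 whenever either norm is zero.
w_norm = tf.norm(var)
u_norm = tf.norm(update)
trust_ratio = tf.where(
    tf.greater(w_norm, 0.0),
    tf.where(tf.greater(u_norm, 0.0), w_norm / u_norm, 1.0),
    1.0)
var.assign_sub(lr * trust_ratio * update)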

pstjohn commented May 31 '20 16:05

Hi @pstjohn, thanks for the update! Unfortunately, I currently do not have access to multi-node, multi-GPU machines. Is it possible to attach the full log file as well? Thank you!

WindQAQ commented Jun 01 '20 01:06

Looks like it also happens if I just run it on a single node (with 6 GPUs). Here's the full output: https://gist.github.com/pstjohn/025df83800ea920e5fa114753092138d (unfortunately there's not much more information). Is there a way to get a more verbose output?
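The only knobs I know to try are the standard TF logging settings (nothing addons-specific), set before importing TensorFlow; not sure they'll surface much before a CHECK failure, though:

import os

# These must be set before TensorFlow is imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'   # show INFO/WARNING messages from the C++ runtime
os.environ['TF_CPP_MIN_VLOG_LEVEL'] = '1'  # extra-verbose C++ logging (very noisy)

import tensorflow as tf

tf.get_logger().setLevel('DEBUG')            # Python-side logger
tf.debugging.set_log_device_placement(True)  # log which device each op runs on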

pstjohn commented Jun 01 '20 12:06

Thanks for the report @pstjohn. I'll be reviewing https://github.com/tensorflow/addons/pull/1770 in the next couple of days, and hopefully we can add some tests that we can use to diagnose the issue.

seanpmorgan commented Jun 03 '20 02:06