
Hi! I need to use TensorFlow 2.0 beta1 with multiple GPUs on one computer for SRGAN, but I get the error below. It seems to be caused by the BatchNorm layers. How can I solve this problem?

Open yonghuixu opened this issue 6 years ago • 3 comments

I have found the problem. There are BatchNorm layers in the SRGAN discriminator. When I delete them, the program works. It also works when I don't use tf.distribute.MirroredStrategy. Does tf.distribute.MirroredStrategy support BatchNorm? How can I use multiple GPUs?
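For context on why BatchNorm specifically triggers this: the layer keeps moving statistics that are updated every training step with an exponential moving average, and under MirroredStrategy that update must target a DistributedVariable. A minimal plain-Python sketch of the update rule itself (the decay value and names are illustrative, not the library's):

```python
def ema_update(moving_stat, batch_stat, decay=0.9):
    """One BatchNorm-style moving-average step: this is the update that
    MirroredStrategy must apply to a mirrored variable on every replica."""
    return decay * moving_stat + (1 - decay) * batch_stat

moving_mean = 0.0
for batch_mean in [1.0, 1.0, 1.0]:
    moving_mean = ema_update(moving_mean, batch_mean)
# moving_mean drifts toward the batch statistic (here, approximately 0.271)
```

If the variable holding `moving_mean` was created outside `strategy.scope()`, it is a plain variable rather than a DistributedVariable, and the distributed update fails.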

def G_D_fn(epoch, batch_data):
    train_data, feature = batch_data
    with tf.GradientTape(persistent=True) as tape:
        generates = G(train_data)
        logits_fake = D(generates)  # If I use VGG it works; if I use D, it raises the error.
        logits_real = D(feature)
        d_loss = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real)) \
               + tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake))
        g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake))
    grad = tape.gradient(g_gan_loss, G.trainable_weights)
    g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
    grad = tape.gradient(d_loss, D.trainable_weights)
    d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
    return d_loss, g_gan_loss

print("Begin training...")
devices = ['/device:GPU:{}'.format(i) for i in range(num_gpu)]
strategy = tf.distribute.MirroredStrategy(devices)
with strategy.scope():
    train_log_dir = 'logs/gradient_tape/train'
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)
    # obtain models
    G = get_G((batch_size, None, None, 3))
    D = get_D((batch_size, None, None, 3))
    lr_v = tf.Variable(lr_init)
    g_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)
    d_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)

G.train()
D.train()

train_ds = get_train_data()

dist_train_ds = strategy.experimental_distribute_dataset(train_ds)

n_batch_epoch = round(n_epoch // batch_size)  
for epoch in range(n_epoch):    
    total_d_loss = 0.0
    total_g_gan_loss = 0.0
    batch = 0
    for batch_data in dist_train_ds:
        batch += 1                 
        per_d_loss, per_g_gan_loss = strategy.experimental_run_v2(G_D_fn, args=(epoch, batch_data, ))
        total_d_loss += strategy.reduce(tf.distribute.ReduceOp.SUM, per_d_loss, axis = None)
        total_g_gan_loss += strategy.reduce(tf.distribute.ReduceOp.SUM, per_g_gan_loss, axis = None)

    d_loss = total_d_loss/batch
    g_gan_loss = total_g_gan_loss/batch

    print("g_gan:{:.6f}, d_loss: {:.9f}".format(g_gan_loss, d_loss))
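One thing worth double-checking in a loop like the one above (independent of the BatchNorm error): with MirroredStrategy, each replica sees only a slice of the global batch, so per-replica losses are usually averaged over the global batch size and then combined with ReduceOp.SUM. A plain-Python sketch of the bookkeeping, with purely illustrative numbers:

```python
# Two replicas, global batch of 8 (per-replica batch of 4); numbers illustrative.
num_replicas = 2
global_batch_size = 8

per_example_losses = [0.5, 0.7, 0.6, 0.2,   # replica 0's slice
                      0.4, 0.8, 0.3, 0.5]   # replica 1's slice

# Each replica divides by the GLOBAL batch size, not its local size of 4,
# so that gradients are not implicitly multiplied by num_replicas:
replica0_loss = sum(per_example_losses[:4]) / global_batch_size
replica1_loss = sum(per_example_losses[4:]) / global_batch_size

# strategy.reduce(ReduceOp.SUM, ...) then yields the true mean over the batch:
total_loss = replica0_loss + replica1_loss
# total_loss is approximately 0.5, the mean loss over all 8 examples
```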

yonghuixu avatar Aug 16 '19 13:08 yonghuixu

Hi, can you provide the error information?

rundiwu avatar Aug 16 '19 14:08 rundiwu

Traceback (most recent call last):
  File "train.py", line 680, in <module>
    train()
  File "train.py", line 319, in train
    per_d_loss, per_g_loss, per_d_loss1, per_d_loss2, per_g_gan_loss, per_mse_loss, per_vgg_loss, per_style_loss = strategy.experimental_run_v2(G_D_fn, args=(epoch, batch_data, ))
  File "/home/dongwen/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py", line 708, in experimental_run_v2
    return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
  File "/home/dongwen/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py", line 1710, in call_for_each_replica
    return self._call_for_each_replica(fn, args, kwargs)
  File "/home/dongwen/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/distribute/mirrored_strategy.py", line 708, in _call_for_each_replica
    fn, args, kwargs)
  File "/home/dongwen/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/distribute/mirrored_strategy.py", line 195, in _call_for_each_replica
    coord.join(threads)
  File "/home/dongwen/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/training/coordinator.py", line 389, in join
    six.reraise(*self._exc_info_to_raise)
  File "/home/dongwen/anaconda3/envs/tf2/lib/python3.7/site-packages/six.py", line 693, in reraise
    raise value
  File "/home/dongwen/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/training/coordinator.py", line 297, in stop_on_exception
    yield
  File "/home/dongwen/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/distribute/mirrored_strategy.py", line 189, in _call_for_each_replica
    **merge_kwargs)
  File "/home/dongwen/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/training/moving_averages.py", line 105, in merge_fn
    return update(strategy, v, value)
  File "/home/dongwen/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/training/moving_averages.py", line 96, in update
    return strategy.extended.update(v, update_fn, args=(value,))
  File "/home/dongwen/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py", line 1458, in update
    return self._update(var, fn, args, kwargs, group)
  File "/home/dongwen/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/distribute/mirrored_strategy.py", line 758, in _update
    assert isinstance(var, values.DistributedVariable)
AssertionError
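The final assertion is the key line: moving_averages.py is trying to update a BatchNorm moving statistic and expects it to be a DistributedVariable, which it only becomes when the layer's variables are created inside strategy.scope(). A minimal sketch of the placement that avoids this, using a plain Keras model for brevity (the TensorLayer get_G/get_D calls would sit in the same position):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # Everything that creates variables -- including BatchNorm's
    # moving_mean / moving_variance -- must run inside the scope so the
    # variables are created as DistributedVariables (mirrored per replica).
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(16, input_shape=(8,)),
        tf.keras.layers.BatchNormalization(),
    ])
    optimizer = tf.keras.optimizers.Adam(1e-4)
```

Any variable created outside the scope (or cached from before the strategy was built) stays a plain tf.Variable and will trip exactly this assertion when its moving average is updated per replica.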


yonghuixu avatar Aug 17 '19 11:08 yonghuixu

Thank you, Vikram. I just tried what you said, but I got the same error.

---Original--- From: "Vikram Meena"[email protected] Date: Mon, Aug 19, 2019 17:31 To: "tensorlayer/srgan"[email protected] Subject: Re: [tensorlayer/srgan] Hi! I need to use tensorflow2.0 beta1 GPUs one computer for SRGAN, but get error as follows. Because of BatchNorm layers, how to solve this problem? (#172)

Hi, I have faced a similar issue. Replace BatchNorm with BatchNorm2d. It should work.
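For anyone else landing here, the suggested change in the TensorLayer discriminator would look roughly like this (the surrounding arguments are illustrative; check model.py in this repo for the exact call):

```diff
-nn = BatchNorm(decay=0.9, act=lrelu)(nn)
+nn = BatchNorm2d(decay=0.9, act=lrelu)(nn)
```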


yonghuixu avatar Aug 19 '19 09:08 yonghuixu