ternarynet

Reproduced ResNet18 CIFAR10 result is 10% lower than reported

Open csyhhu opened this issue 6 years ago • 6 comments

Hi @czhu95 ,

Thanks for providing the code!

Recently I used your code to ternarize a ResNet18 on CIFAR10. First I used tensorpack to train a full-precision ResNet18 down to a validation error of 0.083. However, when I use that model as the initialization and ternarize it (using the example code) with the default threshold t=0.05 from your code, the validation error stays around 0.1843. I tried other values of t but it is still around 0.18, which is about 10% lower in accuracy than what the paper reports.
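For readers unfamiliar with the threshold t being discussed: in TTQ, each layer's weights are mapped to one of three values {+Wp, 0, -Wn}, where the threshold is t times the largest weight magnitude in the layer and Wp/Wn are trained scaling factors. Below is a minimal TF1-style sketch following the paper's rule; it is not necessarily identical to this repo's `ternarize`, and the gradient to the latent weight is a plain straight-through estimator rather than the paper's region-scaled one.

```python
import tensorflow as tf

def ternarize_sketch(w, t=0.05):
    # Per-layer threshold: a fraction t of the largest weight magnitude.
    thresh = t * tf.reduce_max(tf.abs(w))
    wp = tf.get_variable('Wp', initializer=tf.constant(1.0))  # trained scale for +level
    wn = tf.get_variable('Wn', initializer=tf.constant(1.0))  # trained scale for -level
    pos = tf.cast(w > thresh, tf.float32)    # weights mapped to +Wp
    neg = tf.cast(w < -thresh, tf.float32)   # weights mapped to -Wn
    w_t = wp * pos - wn * neg                # everything in between becomes 0
    # Forward pass uses the ternary weights; Wp/Wn get their gradients through
    # w_t via normal backprop, while the latent full-precision w receives an
    # identity (straight-through) gradient from the trick below.
    return w_t + w - tf.stop_gradient(w)
```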

Are there any tricks I am missing, or a mistake I made?

Best regards, Shangyu

csyhhu avatar May 14 '18 04:05 csyhhu

@csyhhu Do you have any update on the accuracy? I got 87% on CIFAR10, still 4 points lower than the paper.

blueardour avatar Jun 11 '19 04:06 blueardour

Hi @blueardour, I haven't had any updates since then.

csyhhu avatar Jun 11 '19 04:06 csyhhu

Hi @csyhhu @blueardour @czhu95, thanks for providing the code! I also got accuracy about 10% lower than reported. The main problem I see is that some of the trained values for Wp/Wn can become negative, which means the two nonzero ternary levels (+Wp and -Wn) end up with the same sign. Is that a big deal, or can I just ignore it? Do I need to fix it? Is there anything I am missing? I modified the ResNet from the tensorpack tutorial and applied the configs from this repo. Sorry, I am new to quantization, so this might be a dumb question. I would appreciate it if you could answer it for me.
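One possible workaround for negative Wp/Wn, offered only as a sketch (this is not how the repo's code handles it, and `positive_scale` is a hypothetical helper): reparameterize each scaling factor through a softplus so it can never change sign; clipping the variables after each update would be a simpler alternative.

```python
import numpy as np
import tensorflow as tf

def positive_scale(name, init=1.0):
    # Learn an unconstrained variable and map it through softplus, so the
    # effective scaling factor used for the ternary level is always > 0.
    raw_init = float(np.log(np.exp(init) - 1.0))   # softplus^-1(init)
    raw = tf.get_variable(name + '_raw', initializer=tf.constant(raw_init))
    return tf.nn.softplus(raw)

# How it would slot into a ternarize-style op (sketch):
#   wp = positive_scale('Wp')            # instead of a raw, sign-free Wp
#   wn = positive_scale('Wn')
#   w_t = wp * pos_mask - wn * neg_mask
```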

Model (ResNet 34 on CIFAR 10):

```python
def inputs(self):
    return [tf.TensorSpec([None, 32, 32, 3], tf.float32, 'input'),
            tf.TensorSpec([None], tf.int32, 'label')]

def build_graph(self, image, label):
    image = image / 128.0
    assert tf.test.is_gpu_available()
    image = tf.transpose(image, [0, 3, 1, 2])

    blocks = [3,4,6,3]

    def new_get_variable(v):
        # don't binarize first and last layer
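        # NB: the final FC layer in this model is named 'linear', not 'fct', so the
        # check below never skips it and the last layer also gets ternarized
        # (see linear/Wn and linear/Wp in the training log further down)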
        if not v.op.name.endswith('W') or 'conv0' in v.op.name or 'fct' in v.op.name:
            return v
        else:
            logger.info("Quantizing weight {}".format(v.op.name))
            return ternarize(v, args.t)
            #return v

    with remap_variables(new_get_variable), argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm],
                  data_format='channels_first'), \
            argscope(Conv2D, use_bias=False):
        logits = (LinearWrap(image)
                  .Conv2D('conv0', 64, 7, strides=2, activation=BNReLU, padding='VALID')
                  .MaxPooling('pool0', 3, strides=2, padding='SAME')
                  .apply2(resnet_group, 'group0', resnet_basicblock, 64, blocks[0], 1)
                  .apply2(resnet_group, 'group1', resnet_basicblock, 128, blocks[1], 2)
                  .apply2(resnet_group, 'group2', resnet_basicblock, 256, blocks[2], 2)
                  .apply2(resnet_group, 'group3', resnet_basicblock, 512, blocks[3], 2)
                  .GlobalAvgPooling('gap')
                  .FullyConnected('linear', 1000)())

    tf.nn.softmax(logits, name='output')

    cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
    cost = tf.reduce_mean(cost, name='cross_entropy_loss')

    wrong = tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, 1)), tf.float32, name='wrong_vector')
    # monitor training error
    add_moving_summary(tf.reduce_mean(wrong, name='train_error'))

    # weight decay on all W of fc layers
    wd_w = tf.train.exponential_decay(0.0002, get_global_step_var(),
                                      480000, 0.2, True)
    wd_cost = tf.multiply(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='wd_cost')
    add_moving_summary(cost, wd_cost)

    add_param_summary(('.*/W', ['histogram']))   # monitor W
    return tf.add_n([cost, wd_cost], name='cost')

def resnet_basicblock(l, ch_out, stride):
    shortcut = l
    #l = BatchNorm('bn1', l)
    #l = tf.nn.relu(l)
    l = Conv2D('conv1', l, ch_out, 3, strides=stride, activation=BNReLU)
    #l = Conv2D('conv1', l, ch_out, 3, strides=stride)
    l = Conv2D('conv2', l, ch_out, 3, activation=get_bn(zero_init=True))
    out = l + resnet_shortcut(shortcut, ch_out, stride, activation=get_bn(zero_init=False))
    return tf.nn.relu(out)

def resnet_shortcut(l, n_out, stride, activation=tf.identity):
    data_format = get_arg_scope()['Conv2D']['data_format']
    n_in = l.get_shape().as_list()[1 if data_format in ['NCHW', 'channels_first'] else 3]
    if n_in != n_out:   # change dimension when channel is not the same
        return Conv2D('convshortcut', l, n_out, 1, strides=stride, activation=activation)
    else:
        return l

def resnet_group(name, l, block_func, features, count, stride):
    with tf.variable_scope(name):
        for i in range(0, count):
            with tf.variable_scope('block{}'.format(i)):
                l = block_func(l, features, stride if i == 0 else 1)
    return l
```

Final Train Output:

```
[0703 23:52:34 @base.py:275] Start Epoch 400 ...
[0703 23:52:48 @base.py:285] Epoch 400 (global_step 156000) finished, time:13.9 seconds.
[0703 23:52:48 @graph.py:73] Running Op sync_variables/sync_variables_from_main_tower ...
[0703 23:52:50 @saver.py:79] Model saved to train_log/cifar10-resnet34-new/model-156000.
[0703 23:52:52 @monitor.py:467] QueueInput/queue_size: 50
[0703 23:52:52 @monitor.py:467] cross_entropy_loss: 0.089525
[0703 23:52:52 @monitor.py:467] group0/block0/conv1/Wn: 2.6743
[0703 23:52:52 @monitor.py:467] group0/block0/conv1/Wp: 0.89064
[0703 23:52:52 @monitor.py:467] group0/block0/conv2/Wn: 1.1225
[0703 23:52:52 @monitor.py:467] group0/block0/conv2/Wp: 0.99695
[0703 23:52:52 @monitor.py:467] group0/block1/conv1/Wn: 1.2538
[0703 23:52:52 @monitor.py:467] group0/block1/conv1/Wp: 0.98401
[0703 23:52:52 @monitor.py:467] group0/block1/conv2/Wn: 0.80803
[0703 23:52:52 @monitor.py:467] group0/block1/conv2/Wp: 0.81799
[0703 23:52:52 @monitor.py:467] group0/block2/conv1/Wn: 0.86712
[0703 23:52:52 @monitor.py:467] group0/block2/conv1/Wp: 0.549
[0703 23:52:52 @monitor.py:467] group0/block2/conv2/Wn: 0.40106
[0703 23:52:52 @monitor.py:467] group0/block2/conv2/Wp: 0.70978
[0703 23:52:52 @monitor.py:467] group1/block0/conv1/Wn: 0.85778
[0703 23:52:52 @monitor.py:467] group1/block0/conv1/Wp: 0.82855
[0703 23:52:52 @monitor.py:467] group1/block0/conv2/Wn: 0.88484
[0703 23:52:52 @monitor.py:467] group1/block0/conv2/Wp: 0.73929
[0703 23:52:52 @monitor.py:467] group1/block0/convshortcut/Wn: 0.73418
[0703 23:52:52 @monitor.py:467] group1/block0/convshortcut/Wp: 0.53777
[0703 23:52:52 @monitor.py:467] group1/block1/conv1/Wn: 0.33549
[0703 23:52:52 @monitor.py:467] group1/block1/conv1/Wp: 0.39828
[0703 23:52:52 @monitor.py:467] group1/block1/conv2/Wn: 0.41539
[0703 23:52:52 @monitor.py:467] group1/block1/conv2/Wp: 0.31004
[0703 23:52:52 @monitor.py:467] group1/block2/conv1/Wn: 0.29682
[0703 23:52:52 @monitor.py:467] group1/block2/conv1/Wp: 0.44211
[0703 23:52:52 @monitor.py:467] group1/block2/conv2/Wn: 0.32866
[0703 23:52:52 @monitor.py:467] group1/block2/conv2/Wp: 0.47317
[0703 23:52:52 @monitor.py:467] group1/block3/conv1/Wn: 0.33542
[0703 23:52:52 @monitor.py:467] group1/block3/conv1/Wp: 0.17248
[0703 23:52:52 @monitor.py:467] group1/block3/conv2/Wn: -0.036823
[0703 23:52:52 @monitor.py:467] group1/block3/conv2/Wp: 0.13485
[0703 23:52:52 @monitor.py:467] group2/block0/conv1/Wn: 0.82662
[0703 23:52:52 @monitor.py:467] group2/block0/conv1/Wp: 0.6845
[0703 23:52:52 @monitor.py:467] group2/block0/conv2/Wn: 0.75725
[0703 23:52:52 @monitor.py:467] group2/block0/conv2/Wp: 0.79772
[0703 23:52:52 @monitor.py:467] group2/block0/convshortcut/Wn: 0.26197
[0703 23:52:52 @monitor.py:467] group2/block0/convshortcut/Wp: 0.40541
[0703 23:52:52 @monitor.py:467] group2/block1/conv1/Wn: -0.12812
[0703 23:52:52 @monitor.py:467] group2/block1/conv1/Wp: 0.23041
[0703 23:52:52 @monitor.py:467] group2/block1/conv2/Wn: -0.0099889
[0703 23:52:52 @monitor.py:467] group2/block1/conv2/Wp: 0.13627
[0703 23:52:52 @monitor.py:467] group2/block2/conv1/Wn: 0.12555
[0703 23:52:52 @monitor.py:467] group2/block2/conv1/Wp: 0.085846
[0703 23:52:52 @monitor.py:467] group2/block2/conv2/Wn: 0.12783
[0703 23:52:52 @monitor.py:467] group2/block2/conv2/Wp: 0.14273
[0703 23:52:52 @monitor.py:467] group2/block3/conv1/Wn: 0.085759
[0703 23:52:52 @monitor.py:467] group2/block3/conv1/Wp: 0.10036
[0703 23:52:52 @monitor.py:467] group2/block3/conv2/Wn: 0.12674
[0703 23:52:52 @monitor.py:467] group2/block3/conv2/Wp: 0.10762
[0703 23:52:52 @monitor.py:467] group2/block4/conv1/Wn: 0.016841
[0703 23:52:52 @monitor.py:467] group2/block4/conv1/Wp: 0.069272
[0703 23:52:52 @monitor.py:467] group2/block4/conv2/Wn: -0.063544
[0703 23:52:52 @monitor.py:467] group2/block4/conv2/Wp: 0.16016
[0703 23:52:52 @monitor.py:467] group2/block5/conv1/Wn: 0.11573
[0703 23:52:52 @monitor.py:467] group2/block5/conv1/Wp: 0.057985
[0703 23:52:52 @monitor.py:467] group2/block5/conv2/Wn: -0.19735
[0703 23:52:52 @monitor.py:467] group2/block5/conv2/Wp: 0.07127
[0703 23:52:52 @monitor.py:467] group3/block0/conv1/Wn: 0.63527
[0703 23:52:52 @monitor.py:467] group3/block0/conv1/Wp: 0.19355
[0703 23:52:52 @monitor.py:467] group3/block0/conv2/Wn: 0.31411
[0703 23:52:52 @monitor.py:467] group3/block0/conv2/Wp: -0.1947
[0703 23:52:52 @monitor.py:467] group3/block0/convshortcut/Wn: 0.54178
[0703 23:52:52 @monitor.py:467] group3/block0/convshortcut/Wp: 0.81051
[0703 23:52:52 @monitor.py:467] group3/block1/conv1/Wn: 0.4934
[0703 23:52:52 @monitor.py:467] group3/block1/conv1/Wp: 0.44481
[0703 23:52:52 @monitor.py:467] group3/block1/conv2/Wn: 0.70687
[0703 23:52:52 @monitor.py:467] group3/block1/conv2/Wp: -0.068037
[0703 23:52:52 @monitor.py:467] group3/block2/conv1/Wn: -0.84213
[0703 23:52:52 @monitor.py:467] group3/block2/conv1/Wp: 0.28623
[0703 23:52:52 @monitor.py:467] group3/block2/conv2/Wn: 0.357
[0703 23:52:52 @monitor.py:467] group3/block2/conv2/Wp: -0.71154
[0703 23:52:52 @monitor.py:467] linear/Wn: 0.76216
[0703 23:52:52 @monitor.py:467] linear/Wp: 0.86959
[0703 23:52:52 @monitor.py:467] train_error: 0.02892
[0703 23:52:52 @monitor.py:467] validation_cost: 0.62721
[0703 23:52:52 @monitor.py:467] validation_error: 0.139
[0703 23:52:52 @monitor.py:467] wd_cost: 30406
[0703 23:52:52 @group.py:48] Callbacks took 3.280 sec in total. ModelSaver: 1.75 seconds; InferenceRunner: 1.42 seconds
[0703 23:52:52 @base.py:289] Training has finished!
```

yy665 avatar Jul 03 '19 22:07 yy665

@Yulun-Yao Hi, I'm sorry to say that I've given up on TTN. I spent weeks trying different optimizer strategies and modifying the gradients based on the empirical experience gathered so far, but I still found the training unstable and was not able to recover the accuracy.

Recently, I moved to LQ-Net and DoReFa. I consistently obtained accuracy better than what the papers report, without much effort, in many scenarios. Even for the a2w1 or a1w1 bit configurations, I got better results than with TTN.
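For reference, the w1 (1-bit weight) quantizer mentioned above, as described in the DoReFa paper, is roughly the following. This is a minimal sketch using a plain straight-through estimator, not necessarily tensorpack's exact implementation:

```python
import tensorflow as tf

def dorefa_w1(w):
    # DoReFa-style 1-bit weights: binarize to +/- E[|w|], with a
    # straight-through estimator so gradients flow to the latent w.
    scale = tf.reduce_mean(tf.abs(w))
    w_bin = tf.sign(w) * scale
    return w + tf.stop_gradient(w_bin - w)
```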

blueardour avatar Jul 08 '19 00:07 blueardour

> @Yulun-Yao Hi, I'm sorry to say that I've given up on TTN. I spent weeks trying different optimizer strategies and modifying the gradients based on the empirical experience gathered so far, but I still found the training unstable and was not able to recover the accuracy.
>
> Recently, I moved to LQ-Net and DoReFa. I consistently obtained accuracy better than what the papers report, without much effort, in many scenarios. Even for the a2w1 or a1w1 bit configurations, I got better results than with TTN.

Thank you for your reply and suggestions!

Btw, have you ever encountered the problem I mentioned above (negative scaling factors resulting in both nonzero weight levels having the same sign)? If you did, were you able to fix it? Would you mind sharing your model and parameters?

yy665 avatar Jul 08 '19 02:07 yy665

Hi @Yulun-Yao, sorry for the late reply. I also gave up on TTQ. For quantization problems, maybe you can add my WeChat (csyhhu) for further discussion.

csyhhu avatar Jul 09 '19 02:07 csyhhu