centerloss icon indicating copy to clipboard operation
centerloss copied to clipboard

  • 本文主要讲解自己对CenterLoss的一些理解,想要看原文的请戳这里 A discriminative feature learning approach for deep face recognition

  • background

    • CenterLoss提出的主要目的是对FaceNet的改进,FaceNet使用的是triple loss,该计算方法需要我们提前计算出三元组,计算量大不说,而且收敛很慢。
    • 所以CenterLoss就被提出来了
  • 自己的理解

    • 其实CenterLoss的想法也很简单,感觉和聚类有异曲同工之妙。
    • 常用的Softmax loss一般只关注于让我们不同的类可以被正确的被分开就可以了,比如下图所示的MNIST数据集使用softmaxloss 训练,提取出来features可视化之后的结果。 image
  • 我们可以看到,虽然的确不同的类别被分到了不同的地方(cluster),但是每个cluster之间距离较近,而且类内的差距较大,也就是说我们得到的features并不能通过一个较为简单的分类器将其区分开来,需要一个相对较为复杂的分类器才可以得到较好的结果。也就是说,我们得到的features并不是一个很好的特征。

  • 那么怎么改进呢?就像我们上面所说,只要让每个cluster之间的距离相对来说远一些,类内的差距小一下,那么区分起来应该就更容易了。

  • 那么具体来说怎么实现让各个cluster之间的距离远一些,类内的差距小一些呢?我们可以考虑下面的loss函数

    • 其中是我们MNIST数据集中样本对应的features,对应的是第个中心,其中
    • 通过上述函数,我们自然就约束了同一个类别中的样本到质心的距离,这样就可以让同一个cluster里面的数据更加聚集。
    • 文章中也给出了更新质心的公式
      • 其中是判断第i个sample的label是不是等于j,如果是则返回1,否则返回0
      • 代表的就是第j个类别的质心
      • 代表的是第i个样本
    • 最终我们优化的target是
  • 实现

    • 第一点:怎么计算CenterLoss
      • 一种是我自己写的,感觉很蠢。。。

        ef calculate_centerloss(x_tensor, label_tensor, centers_tensor, category_num):
        	    loss_tensor = tf.Variable(initial_value=0.0, trainable=False, dtype=tf.float32)
        	    for i in range(category_num):
        	        selected_x = tf.gather(x_tensor, 
        	                               tf.cast(
        	                                   tf.squeeze(
        	                                       tf.where(
        	                                           tf.equal(
        	                                               label_tensor, 
        	                                               [i]
        	                                           ), 
        	                                           None, 
        	                                           None
        	                                       )
        	                                   ),tf.int32))
        	        selected_y = tf.convert_to_tensor(centers_tensor[i], dtype=tf.float32)
        	        selected_x2 = tf.multiply(selected_x, selected_x)
        	        selected_y2 = tf.multiply(selected_y, selected_y)
        	        selected_xy = tf.multiply(selected_x, selected_y)
        	        distance_category = tf.abs(tf.subtract(tf.add(selected_x2, selected_y2), 2*selected_xy))
        	        distance_category = tf.abs(tf.reduce_sum(0.5 * distance_category))
        	        loss_tensor = tf.abs(tf.add(loss_tensor, distance_category))
        	    return loss_tensor
        
      • 还有一种就两行代码,其实他的center loss其实就是l2 loss,所以可以通过下面的代码计算

        ef calculate_centerloss(x_tensor, label_tensor, centers_tensor):
           centers_tensor_batch = tf.gather(centers_tensor, label_tensor)
           loss = tf.nn.l2_loss(x_tensor - centers_tensor_batch)
           return loss
        
      • 大家有兴趣可以跑一跑,两个函数的输出是一样的,不过明显下面的简洁很多

    • 第二点:怎么更新Center?
      • 这里面牵扯两点,1、怎么计算Center更新的值;2、怎么在每次反向传播的时候执行操作1。

      • 先看第一点,怎么计算center的更新值

        ef update_centers(centers, data, labels, category_num):
           centers = np.array(centers)
           data = np.array(data)
           labels = np.array(np.argmax(labels, 1))
           centers_batch = centers[labels]
           diff = centers_batch - data
           for i in range(category_num):
               cur_diff = diff[labels == i]
               cur_diff = cur_diff / (1.0 + 1.0 * len(cur_diff))
               cur_diff = cur_diff * _alpha
               for j in range(len(cur_diff)):
                   centers[i, :] -= cur_diff[j, :]
           return centers
        
      • 接下来我们关注第二点,怎么在每次反向传播的时候执行上述代码

        • tensorflow里面提供了py_func这个函数,通过该函数将普通的python function转化为可以在Graph 上执行的op,该函数的定义是:

          f.py_func(func, inp, Tout, stateful=True, name=None)
          

          func就是我们上述的update_centers, inp是tensor类型的变量,Tout是tf.float32(注意,update_centers里面接收到的参数是numpy数组)

      • 然后在每次训练的时候,也执行该函数返回,如下面所示

        wner_step = tf.py_func(update_centers, [centers_tensor, featuresmap, label_tensor, category_num], tf.float32)
         在每次更新的时候也执行我们自己的step
        , centers_value, = sess.run(
                   [train_op, owner_step], feed_dict={
                       images_tensor: train_images_batch,
                       label_tensor: train_labels_batch,
                       is_training_tensor: True,
                       centers_tensor: centers_value
                   })
        
  • 结果展示

    • 下面分别是不使用CenterLoss和$\lambda$分别等于0.01, 0.1, 1.0的结果 image image image image
  • 具体实现代码:CenterLoss

  • 博客地址