mmpretrain

Adaptation Towards Multi-Task Inference

victoic opened this issue 3 years ago • 5 comments

Checklist

  • I have searched related issues but cannot get the expected help. ☑️
  • I have read related documents and don't know what to do. ☑️

Describe the question you meet

I have a dataset with multiple outputs, namely car color and car model, and I have adapted the CrossEntropy loss to accept a k-dimensional cls_score. The model is currently training. However, I am lost on how to adapt the model so that it correctly returns multiple predicted labels (predicted color and predicted model).

I did not find any example of this in the documentation. Could anyone point me towards an example of multi-task classification with mmcls?

I am training on the vit32 model.

Post related information

  1. Other code you modified in the mmcls folder.

Here is my modified CrossEntropyLoss class; apart from the Dataset class, it is the only modification I made:

```python
import torch.nn as nn

from mmcls.models.builder import LOSSES
from mmcls.models.losses.cross_entropy_loss import (binary_cross_entropy,
                                                    cross_entropy,
                                                    soft_cross_entropy)


@LOSSES.register_module(force=True)
class MultiTaskCrossEntropyLoss(nn.Module):
    """Cross entropy loss for several classification tasks.

    Args:
        num_tasks (int): Number of tasks (k). In ``forward``, the input of
            shape ``(B, C)`` is reshaped into
            ``(B, C // num_tasks, num_tasks)``, so every task must have the
            same number of classes and the label must have shape
            ``(B, num_tasks)``. Defaults to 1.
        use_sigmoid (bool): Whether the prediction uses sigmoid
            or softmax. Defaults to False.
        use_soft (bool): Whether to use the soft version of CrossEntropyLoss.
            Defaults to False.
        reduction (str): The method used to reduce the loss.
            Options are "none", "mean" and "sum". Defaults to 'mean'.
        loss_weight (float): Weight of the loss. Defaults to 1.0.
        class_weight (List[float], optional): The weight for each class with
            shape (C), C is the number of classes (per task when
            ``num_tasks > 1``). Default None.
        pos_weight (List[float], optional): The positive weight for each
            class with shape (C), C is the number of classes. Only enabled in
            BCE loss when ``use_sigmoid`` is True. Default None.
    """

    def __init__(self,
                 num_tasks=1,
                 use_sigmoid=False,
                 use_soft=False,
                 reduction='mean',
                 loss_weight=1.0,
                 class_weight=None,
                 pos_weight=None):
        super(MultiTaskCrossEntropyLoss, self).__init__()
        self.use_sigmoid = use_sigmoid
        self.use_soft = use_soft
        assert not (
            self.use_soft and self.use_sigmoid
        ), 'use_sigmoid and use_soft could not be set simultaneously'

        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight
        self.pos_weight = pos_weight

        if self.use_sigmoid:
            self.cls_criterion = binary_cross_entropy
        elif self.use_soft:
            self.cls_criterion = soft_cross_entropy
        else:
            self.cls_criterion = cross_entropy

        self.num_tasks = num_tasks

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(self.class_weight)
        else:
            class_weight = None

        # only BCE loss has pos_weight
        if self.pos_weight is not None and self.use_sigmoid:
            pos_weight = cls_score.new_tensor(self.pos_weight)
            kwargs.update({'pos_weight': pos_weight})

        # Reshape the input for the multi-task loss:
        # (B, C) -> (B, C // num_tasks, num_tasks). F.cross_entropy treats the
        # dimensions after dim 1 as extra (K-dimensional) dimensions, so
        # ``label`` must have shape (B, num_tasks).
        if self.num_tasks > 1:
            batch_size, num_channels = cls_score.shape
            assert num_channels % self.num_tasks == 0, \
                'the number of output channels must be divisible by num_tasks'
            cls_score = cls_score.reshape(
                batch_size, num_channels // self.num_tasks, self.num_tasks)

        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss_cls
```
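The reshape works because PyTorch's `F.cross_entropy` accepts K-dimensional input of shape `(N, C, d1)` together with targets of shape `(N, d1)`. The minimal sketch below uses plain PyTorch rather than mmcls, and its sizes are made up for illustration; it checks that, assuming every task has the same number of classes, the reshaped loss equals the mean of the per-task losses.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch_size, classes_per_task, num_tasks = 4, 5, 2  # illustrative sizes only

# Flat head output, as produced by a single Linear layer:
# shape (B, classes_per_task * num_tasks).
flat_scores = torch.randn(batch_size, classes_per_task * num_tasks)
# One integer label per task (e.g. color, model): shape (B, num_tasks).
labels = torch.randint(0, classes_per_task, (batch_size, num_tasks))

# Same reshape as MultiTaskCrossEntropyLoss: (B, classes_per_task, num_tasks).
scores = flat_scores.reshape(batch_size, classes_per_task, num_tasks)
joint_loss = F.cross_entropy(scores, labels)  # mean over B * num_tasks terms

# Equivalent per-task computation: task t owns the slice scores[:, :, t].
per_task = [F.cross_entropy(scores[:, :, t], labels[:, t])
            for t in range(num_tasks)]
mean_of_tasks = torch.stack(per_task).mean()

print(torch.allclose(joint_loss, mean_of_tasks))  # True
```

Because `reshape` is row-major, task `t` reads columns `t, t + num_tasks, t + 2 * num_tasks, ...` of the flat output rather than a contiguous block. That is harmless as long as the loss and the head use exactly the same reshape.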

victoic • Jan 04 '22 19:01

It seems that your model outputs a [batch, task1_channel + task2_channel + ...] vector. If you want to get the predictions of your model, you need to add a MultiTaskHead similar to ClsHead; pay particular attention to its simple_test function.
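A minimal sketch of reading per-task predictions out of such a concatenated vector, assuming each task occupies a contiguous block of columns as described above. `split_task_scores` is a hypothetical helper, not an mmcls API, and the class counts (3 colors, 4 models) are illustrative. `torch.split` with a list of sizes also handles tasks with different numbers of classes.

```python
import torch


def split_task_scores(cls_score, task_channels):
    """Split a (B, sum(task_channels)) score tensor into per-task results.

    Hypothetical helper: assumes task t occupies a contiguous block of
    task_channels[t] columns. Returns a list of (label, score) pairs.
    """
    assert cls_score.shape[1] == sum(task_channels)
    results = []
    for chunk in torch.split(cls_score, task_channels, dim=1):
        probs = chunk.softmax(dim=1)  # normalize each task separately
        score, label = probs.max(dim=1)
        results.append((label, score))
    return results


# Illustrative logits: 3 colors followed by 4 models, batch of 2 images.
logits = torch.tensor([[2.0, 0.1, 0.3, 0.0, 5.0, 1.0, 0.2],
                       [0.0, 3.0, 0.1, 4.0, 0.0, 0.0, 0.1]])
(color_label, color_score), (model_label, model_score) = split_task_scores(
    logits, [3, 4])
print(color_label.tolist(), model_label.tolist())  # [0, 1] [1, 0]
```

Whatever layout is chosen (contiguous blocks here, or a reshape), the loss must slice the flat output the same way as the head.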

Ezra-Yu • Jan 05 '22 03:01


Thanks for the reply. Since I'm working with VIT32, I added a MultiTaskVisionTransformerClsHead class, shown below. Is that all I need for multi-task inference? I noticed that I would not be able to use mmcls.apis.inference_model correctly, since L86 of inference.py takes only the first element of the output.

Here is my MultiTaskVisionTransformerClsHead:

```python
import torch.nn.functional as F

from mmcls.models.builder import HEADS
from mmcls.models.heads import VisionTransformerClsHead


@HEADS.register_module(force=True)
class MultiTaskVisionTransformerClsHead(VisionTransformerClsHead):
    """Vision Transformer classifier head for multiple tasks.

    Args:
        num_tasks (int): Number of tasks. ``num_classes`` must be
            ``num_tasks`` times the number of classes per task, and every
            task must have the same number of classes. Defaults to 1.
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        hidden_dim (int): Number of the dimensions for hidden layer. Only
            available during pre-training. Default None.
        act_cfg (dict): The activation config. Only available during
            pre-training. Defaults to Tanh.
    """

    def __init__(self, *args, num_tasks=1, **kwargs):
        super(MultiTaskVisionTransformerClsHead, self).__init__(
            *args, **kwargs)
        self.num_tasks = num_tasks

    def simple_test(self, x, softmax=True, post_process=True):
        """Inference without augmentation.

        Args:
            x (tuple[tuple[tensor, tensor]]): The input features.
                Multi-stage inputs are acceptable but only the last stage will
                be used to classify. Every item should be a tuple which
                includes patch token and cls token. The cls token will be used
                to classify and the shape of it should be
                ``(num_samples, in_channels)``.
            softmax (bool): Whether to softmax the classification score.
                The softmax is taken over the class dimension, i.e.
                separately for each task.
            post_process (bool): Whether to do post processing the
                inference results. It will convert the output to a list.

        Returns:
            Tensor | list: The inference results.
                - If no post processing, the output is a tensor with shape
                  ``(num_samples, num_classes // num_tasks, num_tasks)``.
                - If post processing, the output is a multi-dimensional list
                  of float with dimensions
                  ``(num_samples, num_classes // num_tasks, num_tasks)``.
        """
        x = self.pre_logits(x)
        cls_score = self.layers.head(x)

        # Use the same layout as MultiTaskCrossEntropyLoss:
        # (B, C) -> (B, C // num_tasks, num_tasks).
        if cls_score is not None and self.num_tasks > 1:
            batch_size, num_channels = cls_score.shape
            cls_score = cls_score.reshape(
                batch_size, num_channels // self.num_tasks, self.num_tasks)

        if softmax:
            # dim=1 is the class dimension, so each task is normalized
            # independently.
            pred = (
                F.softmax(cls_score, dim=1) if cls_score is not None else None)
        else:
            pred = cls_score

        if post_process:
            return self.post_process(pred)
        else:
            return pred
```
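With this head, each sample's result is a `(classes_per_task, num_tasks)` block of probabilities, so the predicted label of each task is the argmax along the class dimension. A minimal sketch in plain PyTorch with illustrative sizes; it also implies that the head's `num_classes` in the config should be `classes_per_task * num_tasks`, since `layers.head` produces `num_classes` outputs that are then reshaped.

```python
import torch

classes_per_task, num_tasks = 3, 2  # illustrative sizes only

# One sample's flat head output, shape (1, classes_per_task * num_tasks).
flat_logits = torch.tensor([[0.0, 5.0, 1.0, 0.0, 0.0, 2.0]])

# What simple_test computes with softmax=True and post_process=False:
# reshape to (B, classes_per_task, num_tasks), then softmax over classes.
pred = flat_logits.reshape(-1, classes_per_task, num_tasks).softmax(dim=1)

# Per-task prediction: the most likely class along dim=1, for every task.
pred_score, pred_label = pred.max(dim=1)  # both have shape (B, num_tasks)
print(pred_label.tolist())  # [[1, 0]]: task 0 -> class 1, task 1 -> class 0
```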

victoic • Jan 05 '22 12:01

No, a MultiTaskVisionTransformerClsHead alone is not enough for multi-task inference. You are right that mmcls.apis.inference_model needs to be modified; the post_process in the head and mmcls.apis.test need to be modified too.
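As far as I recall, the single-task `inference_model` in mmcls 0.x keeps only the first sample and returns `pred_label`, `pred_score`, and `pred_class`. One way to extend this to several tasks is to build those three fields per task. The sketch below is a hypothetical helper, `format_multi_task_result`, not an mmcls function; the `{task_name: scores}` and `{task_name: class_names}` input layouts are assumptions.

```python
import numpy as np


def format_multi_task_result(scores, classes):
    """Turn per-task scores for ONE image into an inference_model-style dict.

    Hypothetical helper. Assumed inputs:
        scores:  {task_name: 1-D array of class probabilities}
        classes: {task_name: list of class names}
    """
    result = {}
    for task, task_scores in scores.items():
        label = int(np.argmax(task_scores))
        result[task] = {
            'pred_label': label,
            'pred_score': float(task_scores[label]),
            'pred_class': classes[task][label],
        }
    return result


example = format_multi_task_result(
    {'color': np.array([0.1, 0.7, 0.2]), 'model': np.array([0.6, 0.4])},
    {'color': ['red', 'blue', 'white'], 'model': ['sedan', 'suv']})
print(example['color'])
# {'pred_label': 1, 'pred_score': 0.7, 'pred_class': 'blue'}
```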

I recommend a MultiTaskHead that contains a list of heads, where each head predicts one task. This design also works when the tasks have different numbers of classes, and post_process and the loss do not need to be modified. The MultiTaskHead can return a dict from head.simple_test, and you can handle the dict output in mmcls.apis.inference_model and mmcls.apis.test, so that your modification does not affect other code in mmcls.
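A minimal plain-PyTorch sketch of that design: one linear sub-head per task on a shared feature vector, a dict of per-task losses in training, and a dict of per-task probabilities from `simple_test`. `MultiTaskHead` here is a hypothetical class; a real mmcls head would also need `@HEADS.register_module()` and the mmcls head interface, and the feature size of 768 is only illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiTaskHead(nn.Module):
    """One linear sub-head per task (hypothetical sketch, not mmcls code)."""

    def __init__(self, in_channels, task_classes):
        super().__init__()
        # task_classes: {task_name: num_classes}; counts may differ per task.
        self.heads = nn.ModuleDict({
            task: nn.Linear(in_channels, num_classes)
            for task, num_classes in task_classes.items()
        })

    def forward_train(self, x, gt_labels):
        # gt_labels: {task_name: LongTensor of shape (B,)}
        return {f'{task}_loss': F.cross_entropy(head(x), gt_labels[task])
                for task, head in self.heads.items()}

    def simple_test(self, x):
        # Returns {task_name: (B, num_classes_of_task) probabilities}.
        return {task: F.softmax(head(x), dim=1)
                for task, head in self.heads.items()}


head = MultiTaskHead(in_channels=768, task_classes={'color': 3, 'model': 4})
features = torch.randn(2, 768)
probs = head.simple_test(features)
print({task: tuple(p.shape) for task, p in probs.items()})
# {'color': (2, 3), 'model': (2, 4)}
```

Because every task keeps its own `nn.Linear` and its own softmax, nothing forces the tasks to share a class count, which is the main advantage over reshaping a single flat output.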

Ezra-Yu • Jan 06 '22 03:01

@Ezra-Yu @victoic please check https://github.com/open-mmlab/mmclassification/pull/675. This is my first contribution, so I would be glad to have your feedback on it!

Thanks

piercus • Jan 26 '22 23:01

I'm working on multi-task learning too. Is your approach supported by the literature? Could you please share some references? @victoic

Li-Qingyun • Jul 13 '22 13:07

This issue will be closed as it is inactive, feel free to re-open it if necessary.

tonysy • Dec 12 '22 15:12