
Improving human segmentation

Open dbpprt opened this issue 3 years ago • 13 comments

Hey @xuebinqin,

I recently experimented a bit with your model architecture, aiming to improve segmentation performance on the Supervisely dataset, and wanted to share some of my findings openly. I used the Supervisely dataset to synthesise trimaps and alpha mattes (using PyMatting). I also changed the loss function a bit to give the model different learning signals (nothing proven yet, just some preliminary testing). Nevertheless, the results look quite promising:

[image: comparison]
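The thread does not say exactly how the trimaps were derived from the Supervisely masks. A common approach, assumed here, is to erode and dilate the binary mask and mark the band in between as unknown; the resulting trimap can then be handed to an alpha-matting solver such as PyMatting's `estimate_alpha_cf(image, trimap)`. This NumPy sketch covers only the trimap step, and `mask_to_trimap` and its `radius` parameter are hypothetical names:

```python
import numpy as np


def _min_max_filter(mask: np.ndarray, radius: int, reduce) -> np.ndarray:
    """Apply a (2r+1) x (2r+1) square min or max filter to a 2-D boolean mask."""
    padded = np.pad(mask, radius, mode="edge")
    h, w = mask.shape
    # every shifted view of the window, reduced pixel-wise
    windows = [
        padded[dy:dy + h, dx:dx + w]
        for dy in range(2 * radius + 1)
        for dx in range(2 * radius + 1)
    ]
    return reduce(np.stack(windows), axis=0)


def mask_to_trimap(mask: np.ndarray, radius: int = 5) -> np.ndarray:
    """Turn a binary segmentation mask into a trimap.

    1.0 = definite foreground (eroded mask), 0.0 = definite background
    (outside the dilated mask), 0.5 = unknown band around the boundary.
    """
    fg = mask > 0.5
    eroded = _min_max_filter(fg, radius, np.min)   # erosion
    dilated = _min_max_filter(fg, radius, np.max)  # dilation
    trimap = np.full(mask.shape, 0.5)
    trimap[eroded] = 1.0
    trimap[~dilated] = 0.0
    return trimap
```

The `radius` controls how wide the unknown band is; a wider band gives the matting solver more room around hair, at the cost of more ambiguous pixels.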

I was able to speed up training significantly by:

  • using mixed precision training
  • using channels_last memory format

I also added fairly heavy augmentation (e.g. perspective warping, greyscale, blur, jitter) to prevent the model from overfitting; this seems to help generalization quite effectively. The results above come from a couple of hours of local training on 440 × 440 images.

The loss function is actually quite straightforward and may be of interest to someone:

import torch
import torch.nn.functional as F
from torchvision.transforms.functional import gaussian_blur


def criterion(aux, y, metadata, device):
    # aux = [d0, d1, d2, d3, d4, d5, d6]: the U^2-Net side outputs (after sigmoid)

    def masked_l1_loss(y_hat, y, mask):
        loss = F.l1_loss(y_hat, y, reduction='none')
        loss = (loss * mask.float()).sum()
        # clamp the denominator: an all-zero mask (e.g. a batch without
        # detailed masks) would otherwise give 0 / 0 = NaN
        non_zero_elements = mask.sum().clamp(min=1e-6)
        return loss / non_zero_elements

    # the collate function produces uint8 targets; scale them to [0, 1]
    if y.dtype == torch.uint8:
        y = y.float() / 255.0

    mask = y[:, 0]
    smoothed_mask = gaussian_blur(
        mask.unsqueeze(dim=1), (3, 3), (1.5, 1.5)).squeeze(dim=1)
    unknown_mask = y[:, 1]

    l1_mask = torch.ones(mask.shape, device=device)
    l1_details_mask = torch.zeros(mask.shape, device=device)

    # I synthesised some detailed masks using PyMatting (pymatting.github.io)
    # by synthesising trimaps from the segmentation masks, and I use these
    # in an additional loss so the model learns the unknown areas between
    # foreground and background. This is not perfect, as the generated
    # trimaps and masks are not very accurate, but it seems to go in the
    # right direction.
    detailed_masks = [x['detailed_masks'] for x in metadata]
    for idx, detailed_mask in enumerate(detailed_masks):
        if not detailed_mask:
            # no matte for this sample: exclude the unknown region from the L1 loss
            l1_mask[idx] = l1_mask[idx] - unknown_mask[idx]
        else:
            # matte available: additionally supervise the unknown region (weighted 3x below)
            l1_details_mask[idx] = unknown_mask[idx]

    loss = 0
    for output in aux:
        # side outputs are N x 1 x H x W; drop the channel dim to match mask (N x H x W)
        output = output.squeeze(dim=1)
        loss += masked_l1_loss(output, mask, l1_mask)
        # this loss should give some learning signal to focus on unknown areas
        loss += 3 * masked_l1_loss(output, mask, l1_details_mask)
        # I'm not sure this loss gives the right incentive; the idea is to
        # blur the segmentation mask a bit to reduce background bleeding
        # caused by bad labels. Preliminary results look quite OK.
        loss += F.mse_loss(output, smoothed_mask)

    return loss
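One detail worth guarding in `masked_l1_loss`: when no sample in a batch has a detailed mask, `l1_details_mask` is all zeros, so dividing by `mask.sum()` evaluates 0 / 0, which gives NaN in PyTorch and poisons the summed loss. A NumPy sketch of the same masked mean with a clamped denominator (the `eps` parameter is an assumption, not something from the thread):

```python
import numpy as np


def masked_l1_loss(y_hat, y, mask, eps=1e-6):
    """Mean absolute error over the pixels selected by `mask` (weights in [0, 1])."""
    loss = (np.abs(y_hat - y) * mask).sum()
    # an all-zero mask would otherwise give 0 / 0 = NaN
    return loss / max(mask.sum(), eps)
```

Because the loss is a mean over the selected pixels rather than a sum, the 3x weight on the detail term is relative to a per-pixel average, independent of how large the unknown region is.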

dbpprt · Mar 14 '21 11:03

Hi, Dennis,

Yes, your results look promising. We know that U^2-Net can perform well on the matting task, as others have shown. But your approach of combining unsupervised matting with training is also very inspiring. Thanks for your work.

Regards, Xuebin


-- Xuebin Qin, PhD, Department of Computing Science, University of Alberta, Edmonton, AB, Canada. Homepage: https://webdocs.cs.ualberta.ca/~xuebin/

xuebinqin · Mar 15 '21 08:03

Hi @dennisbappert, great work! Can you explain how exactly the channels_last memory format helps?

I tried some augmentations from albumentations earlier, but the results degraded and the predictions became unconfident. Could you share your augmentation code for clarity?

bluesky314 · Mar 15 '21 09:03

Hey @bluesky314,

Using the channels_last memory format nearly doubled my training throughput.

train.py

if cfg.trainer.channels_last is True:
  model = model.to(memory_format=torch.channels_last)
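As background on why this helps: channels_last keeps a tensor's logical N x C x H x W shape but stores it in N x H x W x C order, so all channels of a pixel sit next to each other in memory, and cuDNN can pick faster convolution kernels for that layout, particularly with mixed precision on Tensor Core GPUs. The NumPy sketch below only illustrates the layout through element strides; in PyTorch, `x.contiguous(memory_format=torch.channels_last).stride()` should report the same strides for this shape.

```python
import numpy as np

x = np.zeros((2, 3, 4, 5), dtype=np.float32)  # N x C x H x W


def element_strides(a):
    """Strides in elements rather than bytes."""
    return tuple(s // a.itemsize for s in a.strides)


# default contiguous layout (NCHW): the W index moves fastest in memory
nchw = np.ascontiguousarray(x)

# channels_last: store as N x H x W x C, then view it with the logical N x C x H x W shape
nhwc_storage = np.ascontiguousarray(x.transpose(0, 2, 3, 1))
channels_last = nhwc_storage.transpose(0, 3, 1, 2)

print(channels_last.shape)             # (2, 3, 4, 5): same logical shape as nchw
print(element_strides(nchw))           # (60, 20, 5, 1)
print(element_strides(channels_last))  # (60, 1, 15, 3): C now moves fastest
```

Note that both the model weights and the input batch need the channels_last layout for the speed-up, which is why the collate function below also allocates its input tensor in that format.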

collate_function.py

import numpy as np
import torch


class CollateFunction:
    def __init__(self, transforms, channels_last):
        self.transforms = transforms
        self.channels_last = channels_last

    def __call__(self, batch):
        tensor, target_tensor, metadata = None, None, []

        for i, sample in enumerate(batch):
            if self.transforms is not None:
                sample = self.transforms(*sample)

            if tensor is None:
                # PIL's Image.size is (width, height), not (height, width)
                w, h = sample[0].size
                memory_format = torch.channels_last if self.channels_last else torch.contiguous_format

                tensor = torch.zeros((len(batch), 3, h, w), dtype=torch.uint8).contiguous(
                    memory_format=memory_format)

                # channels_last is not necessary for the targets
                # note: the targets contain 3 channels:
                #   semantic map, foreground (from trimap), unknown (from trimap)
                target_tensor = torch.zeros((len(batch), 3, h, w), dtype=torch.uint8).contiguous(
                    memory_format=torch.contiguous_format)

            # np.asarray would avoid a copy, but it returns a read-only view of
            # the PIL image, and torch.from_numpy then warns about the
            # non-writable array on every call (millions of warnings), so we
            # stick to np.array, which copies and is less efficient but silent

            x, y, sample_metadata = sample
            x = np.array(x).transpose(2, 0, 1)  # H x W x C -> C x H x W
            y = np.array(y).transpose(2, 0, 1)  # H x W x C -> C x H x W

            # copy into the preallocated batch slot (keeps its memory format)
            tensor[i].copy_(torch.from_numpy(x))
            target_tensor[i].copy_(torch.from_numpy(y))
            metadata.append(sample_metadata)

        return tensor, target_tensor, metadata

And don't forget to train in mixed precision:

train.py

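from torch.cuda.amp import GradScaler, autocast

# create the scaler once, before the training loop
scaler = GradScaler(enabled=cfg.trainer.amp)

# inside the training loop:
with autocast(enabled=cfg.trainer.amp):
    y_hat, aux_outputs = model(x)
    loss = criterion(aux_outputs, y, metadata, device)  # criterion returns a single loss

optimizer.zero_grad()
scaler.scale(loss).backward()  # scale the loss so small fp16 gradients don't underflow
scaler.step(optimizer)  # unscales the gradients; skips the step if they contain inf/NaN
scaler.update()  # adjusts the scale factor for the next iteration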
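`GradScaler` implements dynamic loss scaling: the loss is multiplied by a scale factor before `backward()` so that small fp16 gradients don't underflow, `step()` divides the gradients by that factor and skips the update if any are inf/NaN, and `update()` shrinks the scale after an overflow or grows it after a run of overflow-free steps. The pure-Python toy below mimics that rule on scalar "gradients" with plain SGD, folding `step()` and `update()` into one method. The class is illustrative only, not PyTorch's implementation, although the default constants match GradScaler's documented defaults.

```python
import math


class ToyLossScaler:
    """Simplified model of dynamic loss scaling (not PyTorch's implementation)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def step(self, scaled_grads, params, lr):
        """Unscale the gradients and apply an SGD step only if all are finite.

        Returns True if the parameters were updated.
        """
        grads = [g / self.scale for g in scaled_grads]
        if not all(math.isfinite(g) for g in grads):
            # overflow: skip this step and shrink the scale
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return False
        for i, g in enumerate(grads):
            params[i] -= lr * g
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # a run of finite steps: try a larger scale again
            self.scale *= self.growth_factor
            self._good_steps = 0
        return True
```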

To use augmentations properly, it is essential to apply every spatial transformation to both the image and the mask, with identical parameters.
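The key point is that the random parameters are drawn once and then applied to both arrays. A minimal NumPy sketch with a random crop and horizontal flip; `paired_random_crop_flip` is a hypothetical helper, not part of the author's `lib.data.transforms`:

```python
import numpy as np


def paired_random_crop_flip(image, mask, size, rng):
    """Apply the same random crop and horizontal flip to an image and its mask.

    image: H x W x C array, mask: H x W x K array (e.g. segmentation + trimap channels).
    """
    h, w = image.shape[:2]
    assert mask.shape[:2] == (h, w), "image and mask must be spatially aligned"
    # draw the random parameters once ...
    top = rng.integers(0, h - size + 1)
    left = rng.integers(0, w - size + 1)
    flip = rng.random() < 0.5
    # ... and apply them identically to both arrays
    image = image[top:top + size, left:left + size]
    mask = mask[top:top + size, left:left + size]
    if flip:
        image = image[:, ::-1]
        mask = mask[:, ::-1]
    return np.ascontiguousarray(image), np.ascontiguousarray(mask)


rng = np.random.default_rng(0)
img = rng.random((32, 40, 3))
seg = (img[..., :1] > 0.5).astype(np.uint8)  # a mask derived from the image itself
img_aug, seg_aug = paired_random_crop_flip(img, seg, 16, rng)
```

Since the toy mask is computed from the image, it must still match the augmented image exactly; sampling the flip or crop separately for each array would break that.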

This is the base set of transforms for one dataset (Hydra YAML config):

      - identifier: supervisely
        images_path: ./dataset/supervisely/images/
        masks_path: ./dataset/supervisely/masks/
        detailed_masks: False
        rgba_masks: False
        weight: .5
        transforms:
          train:
            - _target_: lib.data.transforms.RandomRotation
              degrees: 30
            - _target_: lib.data.transforms.RandomResizedCrop
              size: 448
              scale: [ .8, 1.2 ]
            - _target_: lib.data.transforms.RandomHorizontalFlip
              flip_prob: .5
            - _target_: lib.data.transforms.RandomGrayscale
              p: .25
          val:
            - _target_: lib.data.transforms.Resize
              size: 448
              keep_aspect_ratio: True
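Each `_target_` entry names a class that Hydra instantiates with the remaining keys as keyword arguments (`hydra.utils.instantiate` does this, with many more features). As a rough, stdlib-only sketch of what that amounts to: `instantiate`, `PairedCompose` and `build_transforms` are hypothetical names, and how the thread's `lib.data.transforms` actually chains its transforms is not shown.

```python
import importlib


def instantiate(spec: dict):
    """Build an object from a Hydra-style dict: `_target_` is a dotted path, other keys are kwargs."""
    kwargs = {k: v for k, v in spec.items() if k != "_target_"}
    module_name, _, attr = spec["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_name), attr)
    return target(**kwargs)


class PairedCompose:
    """Chain transforms that each take and return an (image, target) pair."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target


def build_transforms(specs):
    """Instantiate a list of `_target_` specs (e.g. the `train:` list) into one pipeline."""
    return PairedCompose([instantiate(spec) for spec in specs])
```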

I also added parts of the AISegment dataset to my training pipeline; for it I use a few more augmentations (portraits are less difficult for the model).

import random

import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode

# apply the same augmentation to both the image and the mask;
# here the mask has 3 channels that also contain the trimap,
# so the mask and the trimap are augmented in one go.

class RandomRotation:
    def __init__(self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=None):
        self.interpolation = interpolation
        self.expand = expand
        self.center = center
        self.fill = fill
        self.degrees = (-degrees, degrees)

    def __call__(self, image, target):
        # sample the angle once so that image and target stay aligned
        angle = random.uniform(self.degrees[0], self.degrees[1])
        kwargs = dict(interpolation=self.interpolation, expand=self.expand,
                      center=self.center, fill=self.fill)
        image = F.rotate(image, angle, **kwargs)
        target = F.rotate(target, angle, **kwargs)
        return image, target

After some more training I had to adjust the loss weights to stabilize training. The unsupervised matting labels have to be treated quite 'softly' because they are noisy. A proper balance between refined mattes and blurred segmentation masks lets the model generalize to the hair region without hard targets. Overall, though, the results look extremely promising even after a handful of epochs.

dbpprt · Mar 15 '21 09:03

Thanks! I'll have to look into this a bit. You mentioned 'e.g. perspective warping, greyscale, blur, jitter' augmentations but did not show any of those above. Did you end up applying them?

bluesky314 · Mar 15 '21 09:03

Yes, but very selectively: e.g. perspective warping only for portraits and not for the Supervisely dataset. I have dropped color jittering for now, as I synthesized a few more images and my dataset is now quite large.

      - identifier: aisegment
        images_path: ./dataset/aisegment/images/
        masks_path: ./dataset/aisegment/masks/
        detailed_masks: False
        rgba_masks: True
        weight: .33
        transforms:
          train:
            - _target_: lib.data.transforms.RandomPerspective
              p: .3
            - _target_: lib.data.transforms.RandomRotation
              degrees: 30
            - _target_: lib.data.transforms.RandomResizedCrop
              size: 448
              scale: [ .8, 1.2 ]
            - _target_: lib.data.transforms.RandomHorizontalFlip
              flip_prob: .5
            - _target_: lib.data.transforms.RandomGrayscale
              p: .25
            - _target_: lib.data.transforms.RandomGaussianSmoothing
              p: .2
          val:
            - _target_: lib.data.transforms.Resize
              size: 448
              keep_aspect_ratio: True

dbpprt · Mar 15 '21 09:03

Perspective warping produces samples like these (taken from training; the 2nd and 3rd rows are the predictions):

[image: Capture]

Warping shows that the model is not very robust to spatial transformations. If you want to do something like background removal on videos (frame by frame), I guess you could use a Siamese-style training approach to enforce spatial consistency:

  • Predict with model A on sample A.
  • Predict with model B on the warped sample A.
  • Unwarp B's prediction and minimize the L1 or L2 loss between both predictions.
  • Requirements: models A and B share weights, and the perspective warp must be differentiable.
  • A risk is that the model collapses and just returns 1 regardless of the input, which would trivially satisfy the loss. This could also be done unsupervised, and I suspect it would help achieve spatial consistency. I'm just thinking out loud here, because several augmentations seem to produce visually bad predictions.
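The consistency idea in the list above can be written as a single loss term. This NumPy sketch is illustrative only: a horizontal flip stands in for the perspective warp (it is its own inverse, so no differentiable unwarping is needed), `model` is any callable, and `consistency_loss` and `hflip` are hypothetical names. It also reproduces the collapse risk: a model that outputs all ones gets zero loss, so a term like this would presumably need another signal, such as the supervised loss, to anchor it.

```python
import numpy as np


def consistency_loss(model, x, warp, unwarp):
    """L1 distance between model(x) and the un-warped prediction on warp(x).

    Both predictions come from the same `model`, i.e. the two branches share weights.
    """
    pred_a = model(x)
    pred_b = unwarp(model(warp(x)))
    return float(np.abs(pred_a - pred_b).mean())


def hflip(a):
    # stand-in for perspective warping: a horizontal flip is its own inverse
    return a[..., ::-1]


x = np.random.default_rng(0).random((8, 8))
threshold = lambda a: (a > 0.5).astype(float)   # per-pixel, so flip-equivariant
collapsed = lambda a: np.ones_like(a)           # ignores the input entirely
left_biased = lambda a: np.cumsum(a, axis=-1)   # depends on direction, not equivariant
```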

Another topic: I'm still experimenting with Gaussian smoothing, both in my loss and as an augmentation, to minimize background bleeding. My labels are extremely noisy because the alpha mattes are estimated; however, smoothing seems to be a key component for producing higher-quality predictions.
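For reference, the `smoothed_mask` target in the criterion blurs the binary mask with a 3 x 3 Gaussian kernel (sigma 1.5), turning hard 0/1 edges into soft values so that mislabelled boundary pixels are penalized less. A NumPy sketch of that operation, intended to approximate torchvision's `gaussian_blur` (separable normalized kernel, reflect padding); the function names are hypothetical:

```python
import numpy as np


def gaussian_kernel_1d(size: int, sigma: float) -> np.ndarray:
    """Normalized 1-D Gaussian kernel of odd length `size`."""
    x = np.arange(size) - (size - 1) / 2
    k = np.exp(-0.5 * (x / sigma) ** 2)
    return k / k.sum()


def smooth_mask(mask: np.ndarray, size: int = 3, sigma: float = 1.5) -> np.ndarray:
    """Separable Gaussian blur of a 2-D mask with reflect padding; output has the input's shape."""
    k = gaussian_kernel_1d(size, sigma)
    r = size // 2
    padded = np.pad(mask.astype(float), r, mode="reflect")
    # blur along each row, then along each column ("valid" drops the padding again)
    rows = np.stack([np.convolve(row, k, mode="valid") for row in padded])
    return np.stack([np.convolve(col, k, mode="valid") for col in rows.T]).T


mask = np.zeros((9, 9))
mask[:, 4:] = 1  # hard vertical edge between columns 3 and 4
soft = smooth_mask(mask)
```

Only the pixels next to the edge change; with a 3 x 3 kernel the soft band is one pixel wide on each side, so larger kernels or sigmas would be needed for a stronger effect.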

One last idea (where I would appreciate some additional thoughts): I'm currently trying to implement a discriminator to improve the visual quality of unknown regions in a self-supervised fashion:

  • The mix of my labels produces quite a good loss distribution, which allows the following assumption: low-loss samples are visually good (high-quality alpha matte predictions).
  • I use these low-loss samples as positive samples to train the discriminator, not as bare alpha mattes but with the matte composited onto a background.
  • Then typical GAN training follows.
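The compositing step in the list above is standard alpha compositing, I = alpha * F + (1 - alpha) * B. A minimal NumPy sketch, assuming the original image is used as the foreground colors and blended over a new background; picking positives as the k lowest-loss samples (`select_low_loss`) and both helper names are assumptions, not the author's code:

```python
import numpy as np


def select_low_loss(losses: np.ndarray, k: int) -> np.ndarray:
    """Indices of the k samples with the lowest loss (assumed to be the best mattes)."""
    return np.argsort(losses)[:k]


def composite(alpha: np.ndarray, foreground: np.ndarray, background: np.ndarray) -> np.ndarray:
    """Alpha compositing: I = alpha * F + (1 - alpha) * B.

    alpha: N x H x W in [0, 1]; foreground/background: N x H x W x 3.
    """
    a = alpha[..., None]  # broadcast the matte over the color channels
    return a * foreground + (1.0 - a) * background


losses = np.array([0.3, 0.05, 0.2, 0.01])
positives = select_low_loss(losses, 2)  # indices of the two lowest-loss samples
```

The composited images (rather than the bare mattes) would then be the discriminator's "real" inputs, so that it judges edge quality in image space.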

dbpprt · Mar 15 '21 10:03

Thanks, will go through this :)

bluesky314 · Mar 16 '21 11:03

I published my training code here (https://github.com/dennisbappert/u-2-net-portrait), along with one of the preliminary pretrained models.

@xuebinqin I kept the model untouched, so the weights I'm providing are compatible with existing applications.

dbpprt · Mar 16 '21 19:03

Hi, Dennis,

Thanks for your great work. We have just included your repo in our README; please feel free to make any suggestions: https://github.com/xuebinqin/U-2-Net/blob/master/README.md

xuebinqin · Mar 16 '21 23:03

Thanks for adding my repo to the README; it looks fine to me.

dbpprt · Mar 19 '21 14:03

@dennisbappert Hi, thanks for the really good idea. I'd like to know why you chose L1 loss as your training loss instead of BCE. Did you run any experiments comparing CE and MSE?

Sparknzz · Apr 22 '21 14:04

@dennisbappert Hi Dennis, I tested your demo.py with your checkpoint, and it seems that square_pad is harmful. I tested with this image: https://i.ibb.co/ZGPFCZj/1-5-Pedestrian.jpg and the result is much better if we keep the original image size (without square_pad).

anguoyang · Mar 10 '22 01:03

@dbpprt Could you please update the weights link? It is not working.

AmrElsersy · Aug 12 '23 23:08