
Demo code

yaoliUoA opened this issue 5 years ago • 8 comments

I think it would be better to have some demo code in this package to visualize the segmentation output from HRNet; even something as simple as the sketch below would help.
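To illustrate what I mean (this is only a hypothetical sketch, not code from this repo: the helper name `overlay_segmentation` is made up, and it assumes you already have a model whose forward pass returns per-class logits of shape `(1, C, H, W)`):

```python
# Hypothetical sketch, not from this repo: colorize a predicted mask
# and blend it over the input image.
import cv2
import numpy as np
import torch


def overlay_segmentation(image_bgr: np.ndarray, logits: torch.Tensor,
                         alpha: float = 0.5) -> np.ndarray:
    # per-pixel class ids from the (1, C, H, W) logits
    mask = logits.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)
    # one fixed pseudo-random color per class id
    colors = np.random.RandomState(0).randint(0, 256, size=(256, 3), dtype=np.uint8)
    color_mask = cv2.resize(colors[mask],
                            (image_bgr.shape[1], image_bgr.shape[0]),
                            interpolation=cv2.INTER_NEAREST)
    # alpha-blend the colorized mask onto the original image
    return cv2.addWeighted(image_bgr, 1 - alpha, color_mask, alpha, 0)
```

Blending with `cv2.addWeighted` keeps the original image visible under the class colors, which makes it easy to spot misclassified regions.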

yaoliUoA avatar Apr 13 '20 02:04 yaoliUoA

Hi Yao, were you able to write some demo code that we can use to visualize the predictions?

I am trying to do the same, but the code is quite dense; I have made some progress, and it would help if you already have something working.

If you are still working on this, would you like to collaborate on creating some demo code, which we could then submit as a PR?

jaintarun avatar May 04 '20 00:05 jaintarun

Hi @yaoliUoA @jaintarun, were you able to write the inference code?

umairanis03 avatar Jul 08 '20 16:07 umairanis03

I also think it would be better to have some demo code in this package to visualize the segmentation output from HRNet.

Linda-L avatar Dec 16 '20 11:12 Linda-L

I think it is necessary too.

MyHubTo avatar Apr 23 '21 01:04 MyHubTo

```python
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import cv2
from PIL import Image
import numpy as np
from torch.nn import functional as F

from lib.config import config
from lib.config import update_config_demo
import lib.models.seg_hrnet as seg_models

# ImageNet normalization constants
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
ignore_label = 255  # Cityscapes convention for classes excluded from training


class FaceSeg():
    def __init__(self,
                 cfg_file='./experiments/cityscapes/seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml',
                 weights='best.pt',
                 device='cpu',
                 imgsz=[700, 700],
                 num_classes=4):  # num_classes includes the background
        update_config_demo(config)

        # cudnn related settings
        cudnn.benchmark = config.CUDNN.BENCHMARK
        cudnn.deterministic = config.CUDNN.DETERMINISTIC
        cudnn.enabled = config.CUDNN.ENABLED

        # build the model
        module = seg_models
        if torch.__version__.startswith('1'):
            module.BatchNorm2d_class = module.BatchNorm2d = torch.nn.BatchNorm2d
        model = module.get_seg_model(config)

        # prefer an explicitly passed checkpoint, fall back to the config
        if weights:
            model_state_file = weights
        elif config.TEST.MODEL_FILE:
            model_state_file = config.TEST.MODEL_FILE
        else:
            print("cannot find a model file: pass `weights` or set config.TEST.MODEL_FILE")
            exit()

        # load the checkpoint, stripping the 'model.' prefix from the keys
        pretrained_dict = torch.load(model_state_file, map_location='cpu')
        if 'state_dict' in pretrained_dict:
            pretrained_dict = pretrained_dict['state_dict']
        model_dict = model.state_dict()
        pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()
                           if k[6:] in model_dict.keys()}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

        if device != 'cpu':
            gpus = list(config.GPUS)
            model = nn.DataParallel(model, device_ids=gpus).cuda()
        else:
            print("running segmentation on the CPU")

        model.eval()
        self.model = model
        self.crop_size = imgsz
        self.num_classes = num_classes
        # Cityscapes label id -> train id mapping
        self.label_mapping = {-1: ignore_label, 0: ignore_label,
                              1: ignore_label, 2: ignore_label,
                              3: ignore_label, 4: ignore_label,
                              5: ignore_label, 6: ignore_label,
                              7: 0, 8: 1, 9: ignore_label,
                              10: ignore_label, 11: 2, 12: 3,
                              13: 4, 14: ignore_label, 15: ignore_label,
                              16: ignore_label, 17: 5, 18: ignore_label,
                              19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11,
                              25: 12, 26: 13, 27: 14, 28: 15,
                              29: ignore_label, 30: ignore_label,
                              31: 16, 32: 17, 33: 18}

    @torch.no_grad()
    def run(self, img0):
        # BGR -> RGB, scale to [0, 1], then apply ImageNet normalization
        image_nor = img0.astype(np.float32)[:, :, ::-1]
        image_nor = image_nor / 255.0
        image_nor -= mean
        image_nor /= std
        ori_height, ori_width, _ = img0.shape

        image = image_nor.copy()  # make the array contiguous for cv2

        final_pred = torch.zeros([1, self.num_classes, ori_height, ori_width])
        new_img = cv2.resize(image, (self.crop_size[0], self.crop_size[1]),
                             interpolation=cv2.INTER_LINEAR)
        height, width = new_img.shape[:-1]

        # HWC -> NCHW tensor
        new_img = new_img.transpose((2, 0, 1))
        new_img = np.expand_dims(new_img, axis=0)
        new_img = torch.from_numpy(new_img)

        preds = self.model(new_img)
        new_size = new_img.size()
        preds = F.interpolate(
            input=preds, size=new_size[-2:],
            mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS
        )
        preds = preds.exp()

        preds = preds[:, :, 0:height, 0:width]

        # upsample the prediction back to the original resolution
        preds = F.interpolate(
            preds, (ori_height, ori_width),
            mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS
        )
        final_pred += preds

        visual = True  # set to False to skip writing the visualization
        if visual:
            palette = self.get_palette(256)
            preds = np.asarray(np.argmax(preds.detach().cpu(), axis=1), dtype=np.uint8)
            for i in range(preds.shape[0]):
                pred = self.convert_label(preds[i], inverse=True)
                save_img = Image.fromarray(pred)
                save_img.putpalette(palette)
                save_img.save('test.png')

    def get_palette(self, n):
        # PASCAL-VOC-style palette: spread the bits of each index over R/G/B
        palette = [0] * (n * 3)
        for j in range(0, n):
            lab = j
            i = 0
            while lab:
                palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
                palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
                palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
                i += 1
                lab >>= 3
        return palette

    def convert_label(self, label, inverse=False):
        temp = label.copy()
        if inverse:
            # map train ids back to the original Cityscapes label ids
            for v, k in self.label_mapping.items():
                label[temp == k] = v
        else:
            for k, v in self.label_mapping.items():
                label[temp == k] = v
        return label


if __name__ == "__main__":
    face_segt = FaceSeg(weights="your_path/cityscapes/seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484/300_checkpoint.pth.tar")
    img = cv2.imread("yours.png")
    face_segt.run(img)
```
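For anyone reusing the snippet: `get_palette` reproduces the standard PASCAL-VOC-style color map (each class index's bits are spread across the R, G and B channels), and `convert_label(..., inverse=True)` maps train ids back to the original Cityscapes label ids so the saved PNG uses the dataset's conventional ids. A quick, hypothetical sanity check of the palette logic:

```python
# get_palette never touches self, so it can be called unbound for a quick check
palette = FaceSeg.get_palette(None, 256)
print(palette[:9])  # first three RGB triples
# expected: [0, 0, 0, 128, 0, 0, 0, 128, 0]  (black, dark red, dark green)
```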

dreamlychina avatar Sep 14 '21 06:09 dreamlychina


My friend, thank you for the code you provided, but there seem to be some problems with the formatting. Could you please repost a correctly formatted version?

alexanderuo avatar Jan 05 '22 08:01 alexanderuo

Just fix the indentation and it will be fine; it got mangled like this when I pasted it over.

dreamlychina avatar Jan 05 '22 08:01 dreamlychina

> Just fix the indentation and it will be fine; it got mangled like this when I pasted it over.

OK, thanks brother. Haha, you found out I'm Chinese.

alexanderuo avatar Jan 06 '22 09:01 alexanderuo