HRNet-Semantic-Segmentation
Demo code
I think it would be helpful to have some demo code in this package to visualize the segmentation output from HRNet.
Hi Yao, were you able to write some demo code that we can use to visualize the predictions?
I am trying to do the same, but the code is quite dense. I have made some progress, but it would help if you already have something working.
If you are still working on this, would you like to collaborate on creating some demo code, which we could then submit as a PR?
Hi @yaoliUoA @jaintarun Were you able to write the inference code?
I think it would be helpful to have some demo code in this package to visualize the segmentation output from HRNet, too.
I think it is necessary too.
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from PIL import Image
from torch.nn import functional as F

from lib.config import config
from lib.config import update_config_demo
import lib.models.seg_hrnet as seg_models

mean = [0.485, 0.456, 0.406]   # ImageNet normalization
std = [0.229, 0.224, 0.225]
ignore_label = 255             # Cityscapes convention for labels excluded from training


class FaceSeg():
    def __init__(self,
                 cfg_file='./experiments/cityscapes/seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml',
                 weights='best.pt', device='cpu',
                 imgsz=[700, 700], num_classes=4):  # num_classes includes the background
        # cudnn related settings
        update_config_demo(config)
        cudnn.benchmark = config.CUDNN.BENCHMARK
        cudnn.deterministic = config.CUDNN.DETERMINISTIC
        cudnn.enabled = config.CUDNN.ENABLED

        # build the model; on PyTorch 1.x, patch in the native BatchNorm2d
        module = seg_models
        if torch.__version__.startswith('1'):
            module.BatchNorm2d_class = module.BatchNorm2d = torch.nn.BatchNorm2d
        model = module.get_seg_model(config)
        dump_input = torch.rand(
            (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
        )

        # the checkpoint path is taken from config.TEST.MODEL_FILE
        if config.TEST.MODEL_FILE:
            model_state_file = config.TEST.MODEL_FILE
        else:
            print("can't find model file:", config.TEST.MODEL_FILE)
            exit()
        pretrained_dict = torch.load(model_state_file)
        if 'state_dict' in pretrained_dict:
            pretrained_dict = pretrained_dict['state_dict']
        model_dict = model.state_dict()
        # strip the 'model.' prefix added by the training wrapper
        pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()
                           if k[6:] in model_dict.keys()}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

        if device != 'cpu':
            gpus = list(config.GPUS)
            model = nn.DataParallel(model, device_ids=gpus).cuda()
        else:
            print("running segmentation on CPU")
        model.eval()

        self.model = model
        self.crop_size = imgsz
        self.num_classes = num_classes
        # Cityscapes label id -> train id mapping
        self.label_mapping = {-1: ignore_label, 0: ignore_label,
                              1: ignore_label, 2: ignore_label,
                              3: ignore_label, 4: ignore_label,
                              5: ignore_label, 6: ignore_label,
                              7: 0, 8: 1, 9: ignore_label,
                              10: ignore_label, 11: 2, 12: 3,
                              13: 4, 14: ignore_label, 15: ignore_label,
                              16: ignore_label, 17: 5, 18: ignore_label,
                              19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11,
                              25: 12, 26: 13, 27: 14, 28: 15,
                              29: ignore_label, 30: ignore_label,
                              31: 16, 32: 17, 33: 18}

    @torch.no_grad()
    def run(self, img0):
        confusion_matrix = np.zeros((config.DATASET.NUM_CLASSES,
                                     config.DATASET.NUM_CLASSES))
        # convert BGR -> RGB and normalize before building the tensor
        image_nor = img0.astype(np.float32)[:, :, ::-1]
        image_nor = image_nor / 255.0
        image_nor -= mean
        image_nor /= std
        print(img0.shape)
        ori_height, ori_width, _ = img0.shape

        image = image_nor.copy()
        stride_h = int(self.crop_size[0] * 1.0)
        stride_w = int(self.crop_size[1] * 1.0)
        final_pred = torch.zeros([1, self.num_classes, ori_height, ori_width])

        new_img = cv2.resize(image, (self.crop_size[0], self.crop_size[1]),
                             interpolation=cv2.INTER_LINEAR)
        height, width = new_img.shape[:-1]
        new_img = new_img.transpose((2, 0, 1))     # HWC -> CHW
        new_img = np.expand_dims(new_img, axis=0)  # add the batch dimension
        new_img = torch.from_numpy(new_img)

        preds = self.model(new_img)
        new_size = new_img.size()
        print("new size", new_size)
        # upsample the logits to the network input size...
        preds = F.interpolate(
            input=preds, size=new_size[-2:],
            mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS
        )
        preds = preds.exp()
        preds = preds[:, :, 0:height, 0:width]
        # ...and then to the original image size
        preds = F.interpolate(
            preds, (ori_height, ori_width),
            mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS
        )
        final_pred += preds

        visual = False
        if visual:
            palette = self.get_palette(256)
            preds = np.asarray(np.argmax(preds.detach().cpu(), axis=1),
                               dtype=np.uint8)
            for i in range(preds.shape[0]):
                pred = self.convert_label(preds[i], inverse=True)
                save_img = Image.fromarray(pred)
                save_img.putpalette(palette)
                save_img.save('test.png')

    def get_palette(self, n):
        # standard VOC-style palette: spread each label's bits across R, G and B
        palette = [0] * (n * 3)
        for j in range(0, n):
            lab = j
            palette[j * 3 + 0] = 0
            palette[j * 3 + 1] = 0
            palette[j * 3 + 2] = 0
            i = 0
            while lab:
                palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
                palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
                palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
                i += 1
                lab >>= 3
        return palette

    def convert_label(self, label, inverse=False):
        temp = label.copy()
        if inverse:
            # train id -> original Cityscapes label id
            for v, k in self.label_mapping.items():
                label[temp == k] = v
        else:
            for k, v in self.label_mapping.items():
                label[temp == k] = v
        return label


if __name__ == "__main__":
    face_segt = FaceSeg(weights="your_path/cityscapes/seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484/300_checkpoint.pth.tar")
    img = cv2.imread("yours.png")
    face_segt.run(img)
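As a possible extension for visualization: below is a minimal sketch of how the class scores could be blended onto the input image with OpenCV, assuming run() is modified to return preds (the (1, C, H, W) score tensor at the original resolution). The overlay_mask helper, the fixed random palette, and the 0.5 blend weight are illustrative choices of mine, not part of this repository.

import cv2
import numpy as np

def overlay_mask(img_bgr, preds, alpha=0.5):
    # preds: (1, C, H, W) score tensor at the image resolution
    mask = preds[0].argmax(0).cpu().numpy().astype(np.uint8)  # (H, W) class ids
    # fixed pseudo-random color per class id (illustrative choice)
    colors = np.random.RandomState(0).randint(0, 255, size=(256, 3), dtype=np.uint8)
    color_mask = colors[mask]                                 # (H, W, 3) uint8
    # alpha-blend the color mask over the original image
    return cv2.addWeighted(img_bgr, 1.0 - alpha, color_mask, alpha, 0)

# usage (hypothetical, assuming run() returns preds):
# preds = face_segt.run(img)
# cv2.imwrite('overlay.png', overlay_mask(img, preds))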
My friend, thank you for the code you provided, but there seem to be some problems with the formatting. Could you please repost it correctly formatted?
Just fix the indentation and it will work; the formatting got mangled when I pasted it here.
OK, thanks bro. You caught me, I'm Chinese haha.