HoHoNet
HoHoNet copied to clipboard
Why do I get such a bad semantic segmentation result?
input image: 'assets/pano_asmasuxybohhcj.png'
This is the semantic segmentation result I get:
# My code is:
"""Single-image semantic-segmentation inference script for HoHoNet.

Loads a config + checkpoint, runs one equirectangular panorama through
``net.infer`` and writes/shows a colorized overlay of the predicted classes.
"""
import os
import argparse
import importlib

import cv2
from natsort import natsorted
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from lib.config import config, update_config, infer_exp_id

if __name__ == '__main__':  # BUG FIX: was `if name == 'main':` (NameError)
    # Parse args & config
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--cfg', default='config/s2d3d_sem/HOHO_depth_dct_efficienthc_TransEn1_h1024_fold1_resnet101rgb.yaml')
    parser.add_argument('--pth')
    parser.add_argument('--out')
    parser.add_argument('--vis_dir', default=True)
    parser.add_argument('--y', action='store_true')
    parser.add_argument('--test_hw', type=int, nargs='*')
    parser.add_argument('opts',
                        help='Modify config options using the command-line',
                        default=None, nargs=argparse.REMAINDER)
    args = parser.parse_args()
    update_config(config, args)
    device = 'cuda' if config.cuda else 'cpu'
    if config.cuda and config.cuda_benchmark:
        # BUG FIX: benchmark was being set to False exactly when the config
        # asked for it to be enabled.
        torch.backends.cudnn.benchmark = True

    # If no checkpoint path given, pick the newest epoch under the exp dir.
    if not args.pth:
        from glob import glob
        exp_id = infer_exp_id(args.cfg)
        exp_ckpt_root = os.path.join(config.ckpt_root, exp_id)
        args.pth = natsorted(glob(os.path.join(exp_ckpt_root, 'ep*pth')))[-1]
        print(f'No pth given, inferring the trained pth: {args.pth}')

    # Init network: dynamically resolve the model class named in the config.
    model_file = importlib.import_module(config.model.file)
    model_class = getattr(model_file, config.model.modelclass)
    net = model_class(**config.model.kwargs)
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    net.load_state_dict(torch.load(args.pth, map_location=device))
    net = net.to(device).eval()  # single .to(device); the original moved twice

    # Start eval
    num_classes = config.model.kwargs.modalities_config.SemanticSegmenter.num_classes
    with torch.no_grad():
        color = cv2.imread('assets/pano_asmasuxybohhcj.png')
        # color = cv2.imread('assets/1.jpg')
        # BUG FIX (likely cause of the bad result): cv2 decodes to BGR, but
        # the config name (...resnet101rgb) indicates the net expects RGB.
        rgb = cv2.cvtColor(color, cv2.COLOR_BGR2RGB)
        x = torch.from_numpy(rgb).permute(2, 0, 1)[None].float() / 255.
        # tuple(): torch.Size never equals a list, so the original comparison
        # could be True even for matching sizes if hw is config-list typed.
        if x.shape[2:] != tuple(config.dataset.common_kwargs.hw):
            # 'area' matches downsampling behavior; see repo's test_sem.py.
            x = F.interpolate(x, size=config.dataset.common_kwargs.hw, mode='area')
        x = x.to(device)
        pred_sem = net.infer(x)['sem']  # logits: (1, num_classes, H, W)

    # Visualization: per-class color overlay blended with the input image.
    if args.vis_dir:
        import matplotlib.pyplot as plt
        cmap = (plt.get_cmap('gist_rainbow')(np.arange(num_classes) / num_classes)[..., :3] * 255).astype(np.uint8)
        vis_sem = cmap[pred_sem[0].argmax(0).cpu().numpy()]
        # BUG FIX: the colormap is RGB but cv2.imwrite/imshow expect BGR.
        vis_sem = cv2.cvtColor(vis_sem, cv2.COLOR_RGB2BGR)
        color = cv2.resize(color, (vis_sem.shape[1], vis_sem.shape[0]))
        vis_sem = (color * 0.2 + vis_sem * 0.8).astype(np.uint8)
        cv2.imwrite('result.jpg', vis_sem)
        cv2.imshow('seg', vis_sem)
        cv2.waitKey(0)
I also checked test_sem.py and infer_sem.ipynb.