HoHoNet icon indicating copy to clipboard operation
HoHoNet copied to clipboard

Why do I get such a bad semantic-segmentation result?

Open zuixiaosanlang opened this issue 1 year ago • 0 comments

input image: 'assets/pano_asmasuxybohhcj.png'

get the sem result: image

and my code is: import os import argparse import importlib

import cv2 from natsort import natsorted

import numpy as np

import torch import torch.nn as nn import torch.nn.functional as F

from lib.config import config, update_config, infer_exp_id

if name == 'main':

# Parse args & config
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--cfg', default='config/s2d3d_sem/HOHO_depth_dct_efficienthc_TransEn1_h1024_fold1_resnet101rgb.yaml')
parser.add_argument('--pth')
parser.add_argument('--out')
parser.add_argument('--vis_dir', default=True)
parser.add_argument('--y', action='store_true')
parser.add_argument('--test_hw', type=int, nargs='*')
parser.add_argument('opts',
                    help='Modify config options using the command-line',
                    default=None, nargs=argparse.REMAINDER)
args = parser.parse_args()
update_config(config, args)
device = 'cuda' if config.cuda else 'cpu'

if config.cuda and config.cuda_benchmark:
    torch.backends.cudnn.benchmark = False

# Init global variable
if not args.pth:
    from glob import glob
    exp_id = infer_exp_id(args.cfg)
    exp_ckpt_root = os.path.join(config.ckpt_root, exp_id)
    args.pth = natsorted(glob(os.path.join(exp_ckpt_root, 'ep*pth')))[-1]
    print(f'No pth given,  inferring the trained pth: {args.pth}')

# Init network
model_file = importlib.import_module(config.model.file)
model_class = getattr(model_file, config.model.modelclass)
net = model_class(**config.model.kwargs).to(device)
net.load_state_dict(torch.load(args.pth))
net = net.to(device).eval()

# Start eval
cm = 0
num_classes = config.model.kwargs.modalities_config.SemanticSegmenter.num_classes
with torch.no_grad():
    color = cv2.imread('assets/pano_asmasuxybohhcj.png')
    # color = cv2.imread('assets/1.jpg')
    x = torch.from_numpy(color).permute(2, 0, 1)[None].float()/255.
    if x.shape[2:] != config.dataset.common_kwargs.hw:
        # x = F.interpolate(x, size=config.dataset.common_kwargs.hw, mode='bilinear', align_corners=False)
        x = torch.nn.functional.interpolate(x, size=config.dataset.common_kwargs.hw, mode='area')
    x = x.to(device)

    pred_sem = net.infer(x)['sem']

    # Visualization
    if args.vis_dir:
        import matplotlib.pyplot as plt
        from imageio import imwrite
        cmap = (plt.get_cmap('gist_rainbow')(np.arange(num_classes) / num_classes)[...,:3] * 255).astype(np.uint8)

        vis_sem = cmap[pred_sem[0].argmax(0).cpu().numpy()]

        color = cv2.resize(color, (vis_sem.shape[1], vis_sem.shape[0]))
        vis_sem = (color * 0.2 + vis_sem * 0.8).astype(np.uint8)
        cv2.imwrite('result.jpg', vis_sem)

        cv2.imshow('seg', vis_sem)
        cv2.waitKey(0)

I also checked test_sem.py and infer_sem.ipynb.

zuixiaosanlang avatar Aug 17 '23 01:08 zuixiaosanlang