Ultra-Fast-Lane-Detection-v2 icon indicating copy to clipboard operation
Ultra-Fast-Lane-Detection-v2 copied to clipboard

Inference on custom dataset

Open hzhv opened this issue 2 years ago • 2 comments

To someone who's interested in testing multiple custom images (cut from video), save the following as infer.py and change the dataset path and output path, then run: python infer.py configs/your_model_config.py --test_model path/to/your/model/file

import torch, os, cv2
from pylab import *
from utils.dist_utils import dist_print
from utils.common import merge_config, get_model
import tqdm
import torchvision.transforms as transforms
from data.dataset import LaneTestDataset


def pred2coords(pred, row_anchor, col_anchor, local_width = 1, original_image_width = 1640, original_image_height = 590,
                row_lane_idx = (1, 2), col_lane_idx = (0, 3)):
    """Decode raw UFLDv2 network output into per-lane pixel coordinates.

    Only the first image of the batch (index 0) is decoded.

    Args:
        pred: dict with tensors 'loc_row' (n, num_grid_row, num_cls_row, num_lanes),
            'loc_col' (n, num_grid_col, num_cls_col, num_lanes), and matching
            'exist_row' / 'exist_col' existence logits.
        row_anchor: normalized vertical anchor positions (fractions of image height).
        col_anchor: normalized horizontal anchor positions (fractions of image width).
        local_width: half-width of the window used for the soft-argmax refinement.
        original_image_width: width in pixels of the image being annotated.
        original_image_height: height in pixels of the image being annotated.
        row_lane_idx: lane channels localized by row anchors (defaults match the
            4-lane CULane/Tusimple setup: the two ego lanes).
        col_lane_idx: lane channels localized by column anchors (the two outer lanes).

    Returns:
        List of lanes; each lane is a list of (x, y) integer pixel tuples.
    """
    batch_size, num_grid_row, num_cls_row, num_lane_row = pred['loc_row'].shape
    batch_size, num_grid_col, num_cls_col, num_lane_col = pred['loc_col'].shape

    # Hard argmax over the grid dimension gives the coarse cell; the 'exist_*'
    # heads flag whether each anchor actually lies on the lane.
    max_indices_row = pred['loc_row'].argmax(1).cpu()
    # n, num_cls, num_lanes
    valid_row = pred['exist_row'].argmax(1).cpu()
    # n, num_cls, num_lanes

    max_indices_col = pred['loc_col'].argmax(1).cpu()
    # n, num_cls, num_lanes
    valid_col = pred['exist_col'].argmax(1).cpu()
    # n, num_cls, num_lanes

    pred['loc_row'] = pred['loc_row'].cpu()
    pred['loc_col'] = pred['loc_col'].cpu()

    coords = []

    for i in row_lane_idx:
        tmp = []
        # Keep a lane only if more than half of its row anchors are marked valid.
        if valid_row[0, :, i].sum() > num_cls_row / 2:
            for k in range(valid_row.shape[1]):
                if valid_row[0, k, i]:
                    # Soft-argmax over a small window around the hard argmax
                    # refines the grid index to sub-cell precision.
                    center = int(max_indices_row[0, k, i])
                    all_ind = torch.arange(max(0, center - local_width),
                                           min(num_grid_row - 1, center + local_width) + 1)

                    out_tmp = (pred['loc_row'][0, all_ind, k, i].softmax(0) * all_ind.float()).sum() + 0.5
                    out_tmp = out_tmp / (num_grid_row - 1) * original_image_width
                    tmp.append((int(out_tmp), int(row_anchor[k] * original_image_height)))
            coords.append(tmp)

    for i in col_lane_idx:
        tmp = []
        # Outer lanes are sparser; a quarter of valid column anchors suffices.
        if valid_col[0, :, i].sum() > num_cls_col / 4:
            for k in range(valid_col.shape[1]):
                if valid_col[0, k, i]:
                    center = int(max_indices_col[0, k, i])
                    all_ind = torch.arange(max(0, center - local_width),
                                           min(num_grid_col - 1, center + local_width) + 1)

                    out_tmp = (pred['loc_col'][0, all_ind, k, i].softmax(0) * all_ind.float()).sum() + 0.5

                    out_tmp = out_tmp / (num_grid_col - 1) * original_image_height
                    tmp.append((int(col_anchor[k] * original_image_width), int(out_tmp)))
            coords.append(tmp)

    return coords

if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True

    args, cfg = merge_config()
    cfg.batch_size = 1  # one image per forward pass for demo generation
    print('setting batch_size to 1 for demo generation')

    dist_print('start testing...')
    assert cfg.backbone in ['18', '34', '50', '101', '152', '50next', '101next', '50wide', '101wide']

    # Number of row anchors depends on the dataset the model was trained on.
    if cfg.dataset == 'CULane':
        cls_num_per_lane = 18
    elif cfg.dataset == 'Tusimple':
        cls_num_per_lane = 56
    else:
        raise NotImplementedError

    net = get_model(cfg)

    # Checkpoints saved from (Distributed)DataParallel prefix every key with
    # 'module.'; strip that prefix so the plain model can load them.  Use
    # startswith, not a substring test, so keys merely containing 'module.'
    # somewhere in the middle are left untouched.
    state_dict = torch.load(cfg.test_model, map_location='cpu')['model']
    compatible_state_dict = {
        (k[len('module.'):] if k.startswith('module.') else k): v
        for k, v in state_dict.items()
    }
    net.load_state_dict(compatible_state_dict, strict=False)
    net.eval()

    # Resize to train_height/crop_ratio so that the bottom-crop below
    # (img[:, -train_height:, :]) reproduces the training-time crop.
    img_transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((int(cfg.train_height / cfg.crop_ratio), cfg.train_width)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    pathname = ""  # TODO: directory of images to be detected
    # Sort so frames are processed (and numbered) in a deterministic order --
    # important when the images are consecutive video frames.
    imgs_path = sorted(os.listdir(pathname))
    i = 0
    for imgname in imgs_path:
        if not imgname.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue
        img = cv2.imread(os.path.join(pathname, imgname))
        if img is None:
            # cv2.imread silently returns None for unreadable/corrupt files;
            # skip instead of crashing on .copy().
            print(imgname + " could not be read, skipped.")
            continue
        im0 = img.copy()  # keep the original BGR image for drawing
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_h, img_w = img.shape[0], img.shape[1]
        img = img_transforms(img)
        img = img[:, -cfg.train_height:, :]  # bottom crop matching crop_ratio
        # NOTE(review): assumes get_model placed the net on cuda:0 -- confirm.
        img = img.to('cuda:0')
        img = torch.unsqueeze(img, 0)

        with torch.no_grad():
            pred = net(img)
        coords = pred2coords(pred, cfg.row_anchor, cfg.col_anchor,
                             original_image_width=img_w,
                             original_image_height=img_h)
        # Draw every predicted lane point on the original-resolution image.
        for lane in coords:
            for coord in lane:
                cv2.circle(im0, coord, 5, (0, 255, 0), -1)
        resname = " " + str(i) + '.png'  # TODO: prepend output directory
        cv2.imwrite(resname, im0)
        i += 1
        print(imgname + " finished.")

Many thanks to @Yutong-gannis

hzhv avatar Oct 06 '22 08:10 hzhv

您好,非常感谢您的代码,我在我自己的图像上进行了测试,完全检测不到车道线,能帮忙猜一下是哪里出了问题吗?万分感谢!!(附截图:1665843298034)

kidcats avatar Oct 15 '22 14:10 kidcats

@Li-Hanzhao Good work!

@kidcats 可能是与训练数据集差异太大了

cfzd avatar Oct 18 '22 13:10 cfzd

outXXX 确实效果不好,你这个数据集分布和训练的不一致。还是自己训练吧

xiaqing10 avatar Nov 23 '22 06:11 xiaqing10

@Li-Hanzhao Good work!

@kidcats 可能是与训练数据集差异太大了

outXXX 确实效果不好,你这个数据集分布和训练的不一致。还是自己训练吧 谢谢,我自己训练看看了

kidcats avatar Nov 23 '22 06:11 kidcats

@Li-Hanzhao Good work!

@kidcats 可能是与训练数据集差异太大了

谢谢

kidcats avatar Nov 23 '22 06:11 kidcats

可以在 `imgs_path = os.listdir(pathname)` 后写一行 `imgs_path = sorted(imgs_path)`,这样程序就会按照检测图片的顺序来输出图片,再把输出的图片拼接,就可以看作是检测视频了:

```python
import cv2
import os

# 图片文件夹路径和输出视频路径
image_folder = '/path/to/images'          # 图片文件夹路径
video_name = '/path/to/output/video.mp4'  # 输出视频路径

# 获取图片文件夹中的所有图片文件名,并按照顺序进行排序
images = sorted([img for img in os.listdir(image_folder) if img.endswith(".jpg") or img.endswith(".png")])

# 读取第一张图片,获取图像尺寸
frame = cv2.imread(os.path.join(image_folder, images[0]))
height, width, _ = frame.shape

# 创建视频编码器对象
fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # avi 和 mp4 修改编码就可以
video = cv2.VideoWriter(video_name, fourcc, 30.0, (width, height))

# 逐帧写入视频
for image in images:
    img_path = os.path.join(image_folder, image)
    frame = cv2.imread(img_path)
    video.write(frame)

# 释放视频编码器对象和关闭视频文件
video.release()
cv2.destroyAllWindows()
```

ennheng avatar May 06 '23 10:05 ennheng

```python
import cv2
import os

# 图片文件夹路径和输出视频路径
image_folder = '/home/ennheng/ufld/result/gaojia1/'     # 图片文件夹路径
output_video = '/home/ennheng/ufld/result/gaojia1.mp4'  # 输出视频路径

# 获取图片文件夹中的所有图片文件名
images = [img for img in os.listdir(image_folder) if img.endswith(".png") or img.endswith(".jpg")]

# 根据图片的生成时间对图片进行排序
images = sorted(images, key=lambda x: os.path.getmtime(os.path.join(image_folder, x)))

# 获取第一张图片的宽度和高度
first_image = cv2.imread(os.path.join(image_folder, images[0]))
height, width, _ = first_image.shape

# 定义视频编码器和输出视频对象
fourcc = cv2.VideoWriter_fourcc(*"mp4v")                 # 编码器
video = cv2.VideoWriter(output_video, fourcc, 30, (width, height))  # 输出视频对象

# 逐帧写入视频
for image_name in images:
    image_path = os.path.join(image_folder, image_name)
    frame = cv2.imread(image_path)
    video.write(frame)  # 将帧写入视频

# 释放视频编码器对象和关闭视频文件
video.release()
cv2.destroyAllWindows()
```

这样好点

ennheng avatar May 06 '23 11:05 ennheng