R-YOLOv4

Dataset format question

Open PJPomPom opened this issue 3 years ago • 9 comments

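Hi, may I ask a question? The label txt format in your dataset is a bit different from the txt format used in r3det. Which version of the rolabelimg software did you use? The txt I get after conversion does not quite match yours. Thanks.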

PJPomPom avatar Oct 27 '21 08:10 PJPomPom

Sorry for replying a few days late. My format was not based on r3det's. Which format are you following?

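I used labelImg2 to do my labeling. The resulting format differs a little from UCAS's, but all you need to do is extract each bounding box's (x, y, w, h, angle) from whatever format you have, that is, the center point, width, height, and angle, and then you can train. You can make the change in load.py.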
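As a rough illustration of that idea, the sketch below reads a label file with one box per line and splits it into the five box values plus a class id. The six-column order (x, y, w, h, angle, class) is an assumption taken from the loader posted in this thread; the function name `read_rotated_labels` and the sample values are made up for illustration.

```python
import io

import numpy as np


def read_rotated_labels(source):
    """Read rows of `x y w h angle class` into a float array of shape (N, 6).

    `source` may be a file path or a file-like object. The column order is an
    assumption; adapt it to whatever your labeling tool writes.
    """
    # ndmin=2 keeps a file with a single box two-dimensional.
    rows = np.loadtxt(source, ndmin=2)
    if rows.size == 0:
        return np.zeros((0, 6))
    if rows.shape[1] != 6:
        raise ValueError(f"expected 6 columns per line, got {rows.shape[1]}")
    return rows


# Hypothetical two-box label file, given inline instead of on disk.
sample = io.StringIO("120.5 80.0 60.0 20.0 0.3 0\n200.0 150.0 40.0 90.0 -1.2 1\n")
boxes = read_rotated_labels(sample)
x, y, w, h, angle, cls = boxes.T  # one 1-D array per column
```

If your tool writes a different layout (for example XML or corner points), only this reading step needs to change, as long as it still produces the same six values per box.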

Below is the modification I made for the label format produced by labelImg2, for your reference.

yingkunwu avatar Nov 03 '21 08:11 yingkunwu

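import glob
import os
import random

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

# gaussian_noise, hsv, pad_to_square, resize, rotate, horisontal_flip and
# vertical_flip are helper functions that must also be imported; their
# module path is not shown in this post.


class ListDataset(Dataset):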
    def __init__(self, list_path, img_size=416, augment=True, multiscale=True, normalized_labels=False):
        self.img_files = list_path

        self.label_files = [
            path.replace("images", "labels").replace(".png", ".txt").replace(".jpg", ".txt")
            for path in self.img_files
        ]
        self.img_size = img_size
        #self.labels = labels
        self.max_objects = 100
        self.augment = augment
        self.multiscale = multiscale
        self.normalized_labels = normalized_labels
        self.min_size = self.img_size - 3 * 32
        self.max_size = self.img_size + 3 * 32
        self.batch_count = 0

    def __getitem__(self, index):

        # ---------
        #  Image
        # ---------
        img_path = self.img_files[index]

        # Extract image as PyTorch tensor
        img = transforms.ToTensor()(Image.open(img_path).convert('RGB'))

        # Handle images with fewer than three channels
        if len(img.shape) != 3:
            img = img.unsqueeze(0)
            # expand() needs a flat size, so unpack the spatial dimensions
            img = img.expand((3, *img.shape[1:]))

        _, h, w = img.shape
        h_factor, w_factor = (h, w) if self.normalized_labels else (1, 1)

        # Photometric augmentations (noise, HSV jitter)
        if self.augment:
            if np.random.random() < 0.25:
                img = gaussian_noise(img, 0.0, np.random.random())
            if np.random.random() < 0.25:
                img = hsv(img)
        # Pad to square resolution
        img, pad = pad_to_square(img, 0)

        _, padded_h, padded_w = img.shape

        # ---------
        #  Label
        # ---------
        label_path = self.label_files[index % len(self.img_files)].rstrip()

        if os.path.exists(label_path):
            boxes = torch.from_numpy(np.loadtxt(label_path).reshape(-1, 6))
            num_targets = len(boxes)

            x, y, w, h, theta, label = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3], boxes[:, 4], boxes[:, 5]
            temp_theta = []
            for t in theta:
                if t > np.pi / 2:
                    t = t - np.pi
                elif t <= -(np.pi / 2):
                    t = t + np.pi
                temp_theta.append(t)

            theta = torch.stack(temp_theta)
            # Both bounds must hold: theta lies in (-pi/2, pi/2]
            assert (-np.pi / 2 < theta).all() and (theta <= np.pi / 2).all()

            for i in range(num_targets):
                if w[i] < h[i]:
                    temp1, temp2 = h[i].clone(), w[i].clone()
                    w[i], h[i] = temp1, temp2
                    if theta[i] > 0:
                        theta[i] = theta[i] - np.pi / 2
                    else:
                        theta[i] = theta[i] + np.pi / 2
            # Swapping w and h must keep theta in (-pi/2, pi/2]
            assert (-np.pi / 2 < theta).all() and (theta <= np.pi / 2).all()

            # Extract coordinates for unpadded + unscaled image
            x1 = w_factor * (x - w / 2)
            y1 = h_factor * (y - h / 2)
            x2 = w_factor * (x + w / 2)
            y2 = h_factor * (y + h / 2)

            # Adjust for added padding
            x1 += pad[0]
            y1 += pad[2]
            x2 += pad[1]
            y2 += pad[3]

            # Returns (x, y, w, h)
            x = ((x1 + x2) / 2) / padded_w
            y = ((y1 + y2) / 2) / padded_h
            w *= w_factor / padded_w
            h *= h_factor / padded_h

            targets = torch.zeros((len(boxes), 7))
            targets[:, 1] = label
            targets[:, 2] = x
            targets[:, 3] = y
            targets[:, 4] = w
            targets[:, 5] = h
            targets[:, 6] = theta
            # Normalized values must lie in [0, 1]; both bounds have to hold
            assert (0 <= x).all() and (x <= 1).all()
            assert (0 <= y).all() and (y <= 1).all()
            assert (0 <= w).all() and (w <= 1).all()
            assert (0 <= h).all() and (h <= 1).all()
        else:
            # Every image is expected to have a label file; fail loudly if one is missing
            raise FileNotFoundError(f"Label file not found: {label_path}")

        # Apply augmentations
        if self.augment:
            if np.random.random() < 0.25:
                img, targets = rotate(img, targets)
            if np.random.random() < 0.5:
                img, targets = horisontal_flip(img, targets)
            if np.random.random() < 0.5:
                img, targets = vertical_flip(img, targets)
        return img_path, img, targets

    def collate_fn(self, batch):
        paths, imgs, targets = list(zip(*batch))
        # Remove empty placeholder targets
        targets = [boxes for boxes in targets if boxes is not None]
        # Add sample index to targets
        for i, boxes in enumerate(targets):
            boxes[:, 0] = i
        targets = torch.cat(targets, 0)
        # Selects new image size every tenth batch
        if self.multiscale and self.batch_count % 10 == 0:
            self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32))
        # Resize images to input shape
        imgs = torch.stack([resize(img, self.img_size) for img in imgs])
        self.batch_count += 1
        return paths, imgs, targets

    def __len__(self):
        return len(self.img_files)


def split_data(data_dir, img_size, batch_size=1, shuffle=True, augment=True, multiscale=False):
    files = sorted(glob.glob(data_dir + "/*.jpg"))
    train_dataset = ListDataset(files, img_size=img_size, augment=augment, multiscale=multiscale)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle,
                                                   pin_memory=True, collate_fn=train_dataset.collate_fn)

    return train_dataset, train_dataloader
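The angle handling in `__getitem__` above can be read as a small standalone rule: wrap the angle into (-pi/2, pi/2], then make `w` the longer side and shift the angle by 90 degrees whenever the sides are swapped. The following is a minimal NumPy sketch of that rule, not the repository's code; the function name `normalize_rotated_boxes` is hypothetical, and it assumes angles in radians within (-3*pi/2, 3*pi/2], which is the range a single wrap can handle.

```python
import numpy as np


def normalize_rotated_boxes(w, h, theta):
    """Return (w, h, theta) with w >= h and theta in (-pi/2, pi/2].

    Assumes theta is in radians and lies in (-3*pi/2, 3*pi/2].
    """
    w = np.asarray(w, dtype=float)
    h = np.asarray(h, dtype=float)
    theta = np.asarray(theta, dtype=float)

    # A rectangle rotated by 180 degrees is the same rectangle, so wrap by pi.
    theta = np.where(theta > np.pi / 2, theta - np.pi, theta)
    theta = np.where(theta <= -np.pi / 2, theta + np.pi, theta)

    # Make w the long side. Swapping the sides turns the box frame by 90
    # degrees, so shift the angle toward zero to stay inside the same range.
    swap = w < h
    shifted = np.where(theta > 0, theta - np.pi / 2, theta + np.pi / 2)
    new_w = np.where(swap, h, w)
    new_h = np.where(swap, w, h)
    new_theta = np.where(swap, shifted, theta)
    return new_w, new_h, new_theta


# First box: angle beyond pi/2 and short side first. Second box: already normalized.
w_out, h_out, t_out = normalize_rotated_boxes([10.0, 20.0], [30.0, 5.0], [2.0, 0.2])
```

Working on copies rather than in place makes the rule easy to check on a handful of boxes before training.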

yingkunwu avatar Nov 03 '21 08:11 yingkunwu

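Got it! I also annotated with labelImg2 (lb2) and then converted the xml to txt following the DOTA dataset format. Thank you very much for your reply; it will be a great help to me!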
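For labels already stored in DOTA style (`x1 y1 x2 y2 x3 y3 x4 y4 category difficult`), one possible way to get back to (x, y, w, h, angle) is sketched below. It assumes the four corners form a rectangle and are listed in consecutive order, and it measures the angle of the first edge in image coordinates; the sign convention the training code expects may differ, so treat the angle as something to verify on a few samples. The function name `dota_line_to_xywha` is hypothetical.

```python
import math


def dota_line_to_xywha(line):
    """Convert one DOTA-style line to (cx, cy, w, h, angle, category).

    Assumes the corners are a rectangle given in consecutive order.
    The angle is in radians, taken from the edge corner1 -> corner2.
    """
    parts = line.split()
    if len(parts) < 9:
        raise ValueError("expected 8 coordinates followed by a category")
    coords = [float(v) for v in parts[:8]]
    category = parts[8]
    xs, ys = coords[0::2], coords[1::2]

    # The center of a rectangle is the mean of its corners.
    cx, cy = sum(xs) / 4.0, sum(ys) / 4.0
    # Side lengths come from the first two edges.
    w = math.hypot(xs[1] - xs[0], ys[1] - ys[0])
    h = math.hypot(xs[2] - xs[1], ys[2] - ys[1])
    # Direction of the first edge; atan2 returns a value in (-pi, pi].
    angle = math.atan2(ys[1] - ys[0], xs[1] - xs[0])
    return cx, cy, w, h, angle, category


axis_aligned = dota_line_to_xywha("0 0 4 0 4 2 0 2 plane 0")
rotated = dota_line_to_xywha("0 0 0 4 -2 4 -2 0 ship 0")
```

The result still needs the same range handling as any other angle source, since `atan2` can return values outside (-pi/2, pi/2].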

PJPomPom avatar Nov 07 '21 09:11 PJPomPom

Hello, could you share the script that converts xml to txt? Thanks.

zjs210 avatar Nov 30 '21 13:11 zjs210

You can refer to this author's repository: https://github.com/ChenCongGit/RoLabelImg_Transform

PJPomPom avatar Dec 01 '21 13:12 PJPomPom

Thanks.

zjs210 avatar Dec 01 '21 13:12 zjs210

Could you share a way to contact you? I have a few more questions I'd like to ask.

zjs210 avatar Dec 07 '21 12:12 zjs210

I haven't gotten it running either. I can add you so we can work through it together.

PJPomPom avatar Dec 07 '21 15:12 PJPomPom

qq:171536269

zjs210 avatar Dec 08 '21 01:12 zjs210