rf-detr

Add YOLO dataset format support

Open: mario-dg opened this issue 9 months ago • 36 comments

This pull request introduces functionality to convert datasets from YOLO format to COCO format within the rfdetr package. The key changes include adding utility functions for this conversion and integrating these functions into the training workflow.

This PR addresses and fixes #69.

New functionality:

  • rfdetr/util/coco_to_yolo.py: Added utility functions is_valid_yolo_format and convert_to_coco to check the format of YOLO datasets and convert them to COCO format.
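For reference, a YOLO detection label file holds one object per line, `class_id x_center y_center width height`, with coordinates normalized to [0, 1]. The PR does not show the body of `is_valid_yolo_format`, so the sketch below is only a hypothetical line-level check of the kind such a function might perform. `is_valid_yolo_line` and `is_valid_yolo_label_file` are illustrative names, not the PR's API, and accepting polygon lines (class plus at least three x/y pairs) is an assumption based on the loader discussed later in this thread.

```python
def is_valid_yolo_line(line: str) -> bool:
    """Return True if `line` looks like a YOLO box or polygon label.

    Hypothetical helper for illustration; not the PR's is_valid_yolo_format.
    """
    values = line.split()
    # Box line: class + 4 coords. Polygon line: class + at least 3 (x, y) pairs.
    is_box = len(values) == 5
    is_polygon = len(values) >= 7 and len(values) % 2 == 1
    if not (is_box or is_polygon):
        return False
    # The class id must be a non-negative integer.
    if not values[0].isdigit():
        return False
    try:
        coords = [float(v) for v in values[1:]]
    except ValueError:
        return False
    # YOLO coordinates are normalized by the image width/height.
    return all(0.0 <= c <= 1.0 for c in coords)


def is_valid_yolo_label_file(text: str) -> bool:
    """Every non-empty line must be valid; an empty file means 'no objects'."""
    return all(is_valid_yolo_line(line) for line in text.splitlines() if line.strip())
```

A dataset-level check would presumably apply this to every `.txt` file under the label folders, but the directory layout it expects is not specified in the PR.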

Integration into training workflow:

  • rfdetr/detr.py: Imported the new utility functions and added logic to convert datasets from YOLO to COCO format during training if they are detected to be in YOLO format.
def train_from_config(self, config: TrainConfig, **kwargs):
    if is_valid_yolo_format(config.dataset_dir):
        config.dataset_dir = convert_to_coco(config.dataset_dir)

    with open(
        os.path.join(config.dataset_dir, "train", "_annotations.coco.json"), "r"
    ) as f:
        anns = json.load(f)
        num_classes = len(anns["categories"])

After a successful conversion, config.dataset_dir is overwritten with the path returned by convert_to_coco, so training continues seamlessly on the converted dataset.
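Conceptually, the conversion turns each normalized YOLO box into an absolute COCO `[x_min, y_min, width, height]` box and gathers images, annotations, and categories into one dictionary per split, which training then reads from `<split>/_annotations.coco.json`. The sketch below is not the PR's `convert_to_coco`: the name `yolo_split_to_coco`, its `(file_name, width, height, label_text)` input layout, the assumption that image sizes are already known, and the choice of category IDs equal to YOLO class IDs are all illustrative assumptions.

```python
def yolo_box_to_coco(values, img_w, img_h):
    """Convert normalized YOLO (cx, cy, w, h) to absolute COCO [x_min, y_min, w, h]."""
    cx, cy, w, h = (float(v) for v in values)
    box_w, box_h = w * img_w, h * img_h
    return [cx * img_w - box_w / 2, cy * img_h - box_h / 2, box_w, box_h]


def yolo_split_to_coco(samples, class_names):
    """Build a COCO-style dict for one split (hypothetical helper).

    samples: list of (file_name, width, height, label_text) tuples (assumed layout).
    Only 5-value box lines are converted; polygon lines are skipped here.
    Category IDs are assumed to equal the zero-based YOLO class IDs.
    """
    images, annotations = [], []
    ann_id = 0
    for image_id, (file_name, width, height, label_text) in enumerate(samples):
        images.append({"id": image_id, "file_name": file_name,
                       "width": width, "height": height})
        for line in label_text.splitlines():
            values = line.split()
            if len(values) != 5:
                continue  # skip blank and polygon lines in this sketch
            bbox = yolo_box_to_coco(values[1:], width, height)
            annotations.append({
                "id": ann_id,
                "image_id": image_id,
                "category_id": int(values[0]),
                "bbox": bbox,
                "area": bbox[2] * bbox[3],
                "iscrowd": 0,
            })
            ann_id += 1
    categories = [{"id": i, "name": name} for i, name in enumerate(class_names)]
    return {"images": images, "annotations": annotations, "categories": categories}
```

For example, a 640×480 image with the label line `0 0.5 0.5 0.25 0.5` yields the COCO box `[240.0, 120.0, 160.0, 240.0]`. Since the training code derives `num_classes` from `len(anns["categories"])`, the categories list is what ultimately determines the model's class count.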

Type of change

  • [ x ] New feature (non-breaking change which adds functionality)

How has this change been tested? Please provide a test case or example of how you tested the change.

I will create a Colab later. Locally, the conversion was successful, but I haven't tested this change in a full training workflow yet.

Any specific deployment considerations

For example, documentation changes, usability, usage/costs, secrets, etc.

Docs

Will be updated later.

  • [ ] Docs updated? What were the changes:

mario-dg, Mar 28 '25 13:03

Hi 👋🏻 @mario-dg, thanks a lot for opening this PR!

It’s a solid step toward making training smoother for users working with YOLO datasets, and I appreciate the effort you’ve put into this. That said, I do have a couple of concerns I'd love your thoughts on:

I'm a bit worried about the performance of the YOLO-to-COCO conversion, especially for larger datasets. Have you tried running a full training pipeline on something like TFT-ID or SKU 110k? I’m curious if we’d observe any memory spikes or long conversion times.

It seems like we’re not caching or checking whether the dataset has already been converted. That could lead to unnecessary reprocessing on each run, which might become costly over time.
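One way to address the caching concern would be to store a fingerprint of the YOLO label files next to the converted output and skip conversion while it still matches. This is a hypothetical sketch: the marker file name `.yolo_fingerprint` and the helper names are illustrative and not part of rfdetr. Hashing relative paths, file sizes, and modification times is a cheap heuristic chosen to avoid re-reading every label file; it will miss an edit that leaves both size and mtime unchanged.

```python
import hashlib
from pathlib import Path

MARKER_NAME = ".yolo_fingerprint"  # hypothetical marker file name


def labels_fingerprint(label_dir: Path) -> str:
    """Hash the relative path, size, and mtime of every .txt label file."""
    digest = hashlib.sha256()
    for path in sorted(label_dir.rglob("*.txt")):
        stat = path.stat()
        entry = f"{path.relative_to(label_dir)}:{stat.st_size}:{stat.st_mtime_ns}\n"
        digest.update(entry.encode())
    return digest.hexdigest()


def needs_conversion(label_dir: Path, output_dir: Path) -> bool:
    """True if output_dir has no marker yet or the stored fingerprint is stale."""
    marker = output_dir / MARKER_NAME
    if not marker.exists():
        return True
    return marker.read_text().strip() != labels_fingerprint(label_dir)


def mark_converted(label_dir: Path, output_dir: Path) -> None:
    """Record the fingerprint after a successful conversion."""
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / MARKER_NAME).write_text(labels_fingerprint(label_dir))
```

A training entry point could then call the converter only when `needs_conversion(...)` is true and call `mark_converted(...)` afterwards.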

Given that, I wonder if it might make more sense in the long run to introduce a lightweight native loader for YOLO datasets instead of converting everything to COCO on the fly.

Looking forward to hearing your thoughts! @probicheaux @isaacrob-roboflow

SkalskiP, Mar 28 '25 14:03

@mario-dg what about:

# datasets.__init__.py

def build_dataset(image_set, args, resolution):
    if args.dataset_file == 'coco':
        return build_coco(image_set, args, resolution)
    if args.dataset_file == 'o365':
        return build_o365(image_set, args, resolution)
    if args.dataset_file == 'roboflow':
        return build_roboflow(image_set, args, resolution)
    if args.dataset_file == 'yolo':
        return build_yolo(image_set, args, resolution)
    raise ValueError(f'dataset {args.dataset_file} not supported')
# datasets.yolo.py

def build_yolo(image_set, args, resolution):
    ...
    # Mimic build_coco in datasets.coco.py


# Load the YOLO annotations per usual, but output the results as COCODetection does
import os

import torch
from PIL import Image
from torch.utils.data import Dataset


class YOLODetection(Dataset):
    def __init__(self, img_folder, ann_folder, transforms=None):
        """
        YOLO detection dataset.
        
        Args:
            img_folder: Path to the folder containing images
            ann_folder: Path to the folder containing YOLO annotation .txt files
            transforms: Optional transforms to be applied to images and targets
        """
        super().__init__()
        self.img_folder = img_folder
        self.ann_folder = ann_folder
        self._transforms = transforms
        
        # Get all image files with corresponding annotation files
        self.images = []
        self.annotations = []
        
        # Get supported image extensions
        img_extensions = ['.jpg', '.jpeg', '.png', '.bmp']
        
        # Find all valid image files that have corresponding annotation files
        for filename in os.listdir(img_folder):
            name, ext = os.path.splitext(filename)
            if ext.lower() in img_extensions:
                ann_path = os.path.join(ann_folder, name + '.txt')
                if os.path.exists(ann_path):
                    self.images.append(os.path.join(img_folder, filename))
                    self.annotations.append(ann_path)

    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        # Load image
        img_path = self.images[idx]
        img = Image.open(img_path).convert('RGB')
        img_width, img_height = img.size
        
        # Load annotations
        ann_path = self.annotations[idx]
        boxes = []
        labels = []
        
        # Read YOLO format annotations
        with open(ann_path, 'r') as f:
            for line in f.readlines():
                if line.strip():
                    values = line.strip().split()
                    if len(values) >= 5:  # class, coordinates
                        class_id = int(values[0])
                        
                        if len(values) == 5:  # Standard bounding box format
                            # Convert normalized YOLO format to absolute coordinates
                            x_center = float(values[1]) * img_width
                            y_center = float(values[2]) * img_height
                            width = float(values[3]) * img_width
                            height = float(values[4]) * img_height
                            
                            # Convert from center coordinates to corners (xyxy),
                            # the box format COCODetection's target preparation produces
                            x_min = x_center - (width / 2)
                            y_min = y_center - (height / 2)
                            x_max = x_min + width
                            y_max = y_min + height
                            
                            boxes.append([x_min, y_min, x_max, y_max])
                            labels.append(class_id)
                        
                        elif len(values) > 5:  # Polygon format
                            # Parse polygon coordinates as (x, y) pairs
                            polygon_points = []
                            for i in range(1, len(values) - 1, 2):
                                # Convert normalized coordinates to absolute
                                x = float(values[i]) * img_width
                                y = float(values[i + 1]) * img_height
                                polygon_points.append((x, y))
                            
                            # Convert polygon to its enclosing bounding box (xyxy)
                            if polygon_points:
                                xs = [p[0] for p in polygon_points]
                                ys = [p[1] for p in polygon_points]
                                boxes.append([min(xs), min(ys), max(xs), max(ys)])
                                labels.append(class_id)
        
        # Create target dictionary; reshape keeps shape (0, 4) for images without objects
        target = {
            'boxes': torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            'labels': torch.tensor(labels, dtype=torch.int64),
            'image_id': torch.tensor([idx]),
            'orig_size': torch.tensor([img_height, img_width]),
        }
        
        # Apply transforms if available
        if self._transforms is not None:
            img, target = self._transforms(img, target)
            
        return img, target

# Do something with transforms
...

Jordan-Pierce, Mar 28 '25 15:03

Yes, I agree. This is a very naive approach. I haven't tested my implementation extensively yet, but from previous experiments I know that the supervision conversion can take a long time for large datasets. Hence, a new dataloader approach, as @Jordan-Pierce already suggested in the issue, might be needed.

mario-dg, Mar 28 '25 15:03

in general I'm in favor of a native data loader as opposed to conversion. the downside is I don't want to build custom loaders for n+1 dataset formats :) do y'all think it's likely people will want native support for more than just yolo and coco formats? if no I am pro data loader

isaacrob-roboflow, Mar 28 '25 18:03

I think lots of people (not all, obviously) who might want to use RF-DETR are people who might also use libraries that use YOLO-formatted datasets (🙋‍♂️). Other likely contenders would be, what, Pascal VOC?

But given the demographic of people who use RF and also libraries that use YOLO-formatted datasets, I feel like these two formats would cover a large share of users.

Jordan-Pierce, Mar 28 '25 18:03

I've been working on a YOLO-format data loader for a while now. The loader itself seems to work when tested in isolation with the script below, but training still fails.

Colab available now: https://colab.research.google.com/drive/143icsDIfvgOmtzfzEDLEeh4wMQu2g181?usp=sharing

#!/usr/bin/env python3
"""
Test script for the YOLO dataloader in RF-DETR.
This script checks if the YOLO dataloader can correctly read a YOLO format dataset.
"""

import os
import argparse
import matplotlib.pyplot as plt
import numpy as np
import torch
import random
from matplotlib.patches import Rectangle

from rfdetr.datasets.yolo import build_yolo


def parse_args():
    parser = argparse.ArgumentParser(description="Test YOLO dataloader")
    parser.add_argument("--dataset-dir", type=str, required=True, help="Path to YOLO dataset")
    parser.add_argument("--image-set", type=str, default="train", help="Image set (train, val, test)")
    parser.add_argument("--resolution", type=int, default=640, help="Image resolution")
    parser.add_argument("--n-samples", type=int, default=5, help="Number of samples to display")
    parser.add_argument("--random", action="store_true", help="Randomly select samples instead of the first n")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    return parser.parse_args()


class Args:
    """Dummy class to hold args for the dataloader"""
    def __init__(self, dataset_dir):
        self.dataset_dir = dataset_dir
        self.multi_scale = False
        self.expanded_scales = False
        self.square_resize_div_64 = False


def plot_sample(img, target, idx, class_names, args):
    """Plot a sample with bounding boxes"""
    # Convert tensor to numpy for visualization
    img_np = img.permute(1, 2, 0).numpy()
    
    # Denormalize the image
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img_np = img_np * std + mean
    img_np = np.clip(img_np, 0, 1)
    
    # Create the figure
    fig, ax = plt.subplots(1, figsize=(10, 10))
    ax.imshow(img_np)
    
    # Get boxes and labels
    boxes = target["boxes"].numpy()
    labels = target["labels"].numpy()
    
    # Get image dimensions for denormalizing coordinates
    h, w = img_np.shape[0:2]
    
    # Plot each box
    for box, label in zip(boxes, labels):
        # RF-DETR stores boxes in normalized [centerX, centerY, width, height] format
        # We need to convert to absolute pixel coordinates for visualization
        cx, cy, bw, bh = box
        
        # Convert center coordinates to top-left
        x1 = (cx - bw/2) * w
        y1 = (cy - bh/2) * h
        width = bw * w
        height = bh * h
        
        # Print for debugging
        print(f"Box: {box}, Label: {label}")
        print(f"  Denormalized: x={x1:.1f}, y={y1:.1f}, w={width:.1f}, h={height:.1f}")
        
        # Create and add the rectangle
        rect = Rectangle((x1, y1), width, height, linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
        
        # If class_names is available, use the class name instead of the numeric label
        class_label = class_names[label-1] if class_names and label-1 < len(class_names) else f"Class: {label}"
        ax.text(x1, y1, class_label, color='white', fontsize=12, 
                backgroundcolor='red', verticalalignment='top')
    
    ax.set_title(f"Sample {idx} - {len(boxes)} objects detected")
    plt.axis('off')
    plt.tight_layout()
    
    # Create output directory if it doesn't exist
    os.makedirs("test_output", exist_ok=True)
    plt.savefig(f"test_output/sample_{idx}.png")
    plt.close()


def main(args):
    print(f"Testing YOLO dataloader with dataset: {args.dataset_dir}")
    
    # Set random seed for reproducibility
    if args.random:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        print(f"Using random seed: {args.seed}")
    
    # Initialize Args for dataloader
    loader_args = Args(args.dataset_dir)
    
    try:
        # Build dataset
        print(f"Building dataset with image_set={args.image_set}, resolution={args.resolution}")
        dataset = build_yolo(args.image_set, loader_args, args.resolution)
        
        print(f"Dataset size: {len(dataset)} samples")
        print(f"Class names: {dataset.class_names}")
        
        # Verify dataset configurations
        print(f"Class mapping (YOLO class ID -> COCO class ID): {dataset.class_to_coco_id}")
        
        # Verify COCO API functionality
        coco_api = dataset.coco
        print(f"Total annotations: {len(coco_api.anns)}")
        print(f"Total categories: {len(coco_api.cats)}")
        print(f"Total images: {len(coco_api.imgs)}")
        
        # Print actual category IDs for debugging
        print(f"Category IDs in the dataset: {list(coco_api.cats.keys())}")
        
        # Select sample indices
        n_samples = min(args.n_samples, len(dataset))
        if args.random:
            sample_indices = random.sample(range(len(dataset)), n_samples)
            print(f"Randomly selected samples: {sample_indices}")
        else:
            sample_indices = list(range(n_samples))
            print(f"Using first {n_samples} samples")
        
        # Display selected samples
        print(f"Displaying {n_samples} samples...")
        
        # Check for potential issues in samples
        for i, idx in enumerate(sample_indices):
            try:
                img, target = dataset[idx]
                print(f"Sample {i} (dataset index {idx}):")
                print(f"  Image shape: {img.shape}")
                print(f"  Boxes: {target['boxes'].shape}")
                print(f"  Labels: {target['labels']}")
                
                # Validate labels - check if any label is out of range
                if len(target['labels']) > 0:
                    max_label = target['labels'].max().item()
                    min_label = target['labels'].min().item()
                    num_classes = len(dataset.class_names) + 1  # +1 for background class
                    
                    if max_label >= num_classes or min_label <= 0:
                        print(f"  WARNING: Invalid label range: min={min_label}, max={max_label}, valid range=[1, {num_classes-1}]")
                        
                    # Debug output for label values
                    label_counts = {}
                    for label in target['labels']:
                        l = label.item()
                        if l not in label_counts:
                            label_counts[l] = 0
                        label_counts[l] += 1
                    print(f"  Label distribution: {label_counts}")
                
                # Plot sample
                plot_sample(img, target, i, dataset.class_names, args)
            except Exception as e:
                print(f"Error processing sample {idx}: {str(e)}")
                import traceback
                traceback.print_exc()
        
        print(f"Sample images saved to test_output/ directory.")
    except Exception as e:
        print(f"Error testing YOLO dataloader: {str(e)}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    args = parse_args()
    main(args) 
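The script's label check assumes valid labels lie in `[1, num_classes - 1]` with `num_classes = len(class_names) + 1`, which implies the loader shifts zero-based YOLO class IDs by one. Below is a minimal sketch of building such a mapping together with matching COCO-style categories, assuming an offset of one. `build_class_mapping` is a hypothetical name, and this is not necessarily how the PR constructs `dataset.class_to_coco_id`.

```python
def build_class_mapping(class_names, offset=1):
    """Map zero-based YOLO class IDs to COCO-style category IDs.

    offset=1 mirrors the test script's expectation that labels fall in
    [1, len(class_names)]; the offset itself is an assumption.
    """
    class_to_coco_id = {
        yolo_id: yolo_id + offset for yolo_id in range(len(class_names))
    }
    categories = [
        {"id": class_to_coco_id[yolo_id], "name": name}
        for yolo_id, name in enumerate(class_names)
    ]
    return class_to_coco_id, categories
```

With this mapping, index 0 stays unused, which is consistent with the script's `class_names[label-1]` lookup in `plot_sample` and its `[1, num_classes - 1]` range check.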

mario-dg, Mar 29 '25 12:03

@mario-dg I'm blown away by the progress in this PR! Have you managed to solve it?

I agree with @Jordan-Pierce. YOLO format support is a must-have in a project that is here to compete with YOLO models.

SkalskiP, Mar 29 '25 17:03

Nice job @mario-dg !

Jordan-Pierce, Mar 29 '25 17:03

No, not yet; I'll try to solve it tomorrow. In the meantime, maybe @Jordan-Pierce has an idea? The main parts of the data loader are similar to his initial idea.

mario-dg, Mar 29 '25 20:03

Ok, I'm getting there. The first test training run went through, both locally and in a Colab 🚀 I will try to train a model on a larger dataset, but I am limited by the free Google Colab runtime. The datasets used can be found in this Colab (the same one as before). They were simply downloaded in different formats from Roboflow Universe. https://colab.research.google.com/drive/143icsDIfvgOmtzfzEDLEeh4wMQu2g181?usp=sharing

Small dataset in YOLO format: [figure: training metrics plot (metrics_plot_small_yolo)]

Small dataset in COCO format: [figure: training metrics plot (metrics_plot_small_coco)]

mario-dg, Mar 31 '25 08:03

Awesome, @mario-dg! 🔥 I'll run some tests myself and do a code review.

SkalskiP, Mar 31 '25 11:03

@mario-dg I did a first round of code review. Is there any chance you could start working on those changes today? Also, are you on our Discord server?

SkalskiP, Mar 31 '25 11:03

Yes, I will work on them right away! And yes, I am, but I haven't been very active yet.

mario-dg, Mar 31 '25 11:03

Understood. Let me know once you have all the updates ;)

SkalskiP, Mar 31 '25 12:03

Letting you know 😄

mario-dg, Mar 31 '25 12:03

@mario-dg I just added a second batch of comments. We are moving in the right direction! Please let me know if you'll be able to take a look at it right away. 🙏🏻

We have changed a lot since your last training tests. It would be awesome if you could test it once you're done, just to make sure we didn't break anything.

SkalskiP, Mar 31 '25 13:03

I addressed most of them and will be able to do the last remaining one tonight. I'll also start a test training run then 😄

mario-dg, Mar 31 '25 14:03

Thanks a lot, @mario-dg! 🙏🏻 We will have one more review round. I'm sorry to drag you through this review process, but I want us to get it right.

SkalskiP, Mar 31 '25 16:03

No worries. I want to make sure that we have a solid code base that everyone can build on, so do as many rounds as you feel are necessary.

mario-dg, Mar 31 '25 16:03

We are on the same page! 🔥

SkalskiP, Mar 31 '25 16:03

@SkalskiP, we're good to go. I implemented your feedback and ran the test again with the small dataset. Everything is working again 🚀

mario-dg, Mar 31 '25 18:03

Hi @mario-dg, everything is looking good, and thanks for your contribution. I'm afraid another merged PR has introduced a merge conflict here; can you resolve it? Then we should be good to go.

probicheaux, Apr 01 '25 02:04

CLA assistant check
All committers have signed the CLA.

CLAassistant, Apr 01 '25 04:04

@probicheaux, I fixed the merge conflict. Sorry for the force push; I had accidentally committed with my work account.

mario-dg, Apr 01 '25 05:04

Any news?

mario-dg, Apr 04 '25 13:04

@probicheaux, are you comfortable owning this and making it happen? Asking since you were involved earlier.

isaacrob-roboflow, Apr 04 '25 17:04

@mario-dg I'll drive it to the finish line next week ;)

SkalskiP, Apr 04 '25 18:04

@SkalskiP, anything I can/should improve or change?

mario-dg, Apr 04 '25 22:04

@SkalskiP, when can we expect updates?

mario-dg, Apr 15 '25 07:04

Any updates on this?

Jordan-Pierce, Apr 22 '25 18:04