
How to tell if Faster RCNN Detection model is overfitting

Open · 1andDone opened this issue 2 years ago · 2 comments

I'm not sure how to tell whether the Faster R-CNN detection model I'm training is overfitting, given that the validation loss is not computed in the evaluate function shown here and below.

Any help would be greatly appreciated.

# From torchvision references/detection/engine.py. `utils`, `CocoEvaluator`, and
# `get_coco_api_from_dataset` are helpers from the same references/detection folder;
# `_get_iou_types` is defined elsewhere in engine.py.
import time

import torch

import utils
from coco_eval import CocoEvaluator
from coco_utils import get_coco_api_from_dataset


@torch.inference_mode()
def evaluate(model, data_loader, device):
    n_threads = torch.get_num_threads()
    # FIXME remove this and make paste_masks_in_image run on the GPU
    torch.set_num_threads(1)
    cpu_device = torch.device("cpu")
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = "Test:"

    coco = get_coco_api_from_dataset(data_loader.dataset)
    iou_types = _get_iou_types(model)
    coco_evaluator = CocoEvaluator(coco, iou_types)

    for images, targets in metric_logger.log_every(data_loader, 100, header):
        images = list(img.to(device) for img in images)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        model_time = time.time()
        outputs = model(images)

        outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
        model_time = time.time() - model_time

        res = {target["image_id"]: output for target, output in zip(targets, outputs)}
        evaluator_time = time.time()
        coco_evaluator.update(res)
        evaluator_time = time.time() - evaluator_time
        metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    coco_evaluator.synchronize_between_processes()

    # accumulate predictions from all images
    coco_evaluator.accumulate()
    coco_evaluator.summarize()
    torch.set_num_threads(n_threads)
    return coco_evaluator
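
For what it's worth, a separate validation loss can be computed for torchvision detection models, because they return a dict of losses when called with targets while in training mode. The following is only a rough sketch, not part of the reference scripts; compute_validation_loss is a hypothetical helper name, and the caveat about train mode applies to your specific model:

@torch.no_grad()
def compute_validation_loss(model, data_loader, device):  # hypothetical helper
    was_training = model.training
    # the detection models only return the loss dict in train mode;
    # note that train() also affects layers like BatchNorm/Dropout (the torchvision
    # detection backbones use FrozenBatchNorm2d, but check your own model)
    model.train()
    total_loss = 0.0
    num_batches = 0
    for images, targets in data_loader:
        images = [img.to(device) for img in images]
        targets = [
            {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()}
            for t in targets
        ]
        loss_dict = model(images, targets)  # e.g. loss_classifier, loss_box_reg, ...
        total_loss += sum(loss for loss in loss_dict.values()).item()
        num_batches += 1
    model.train(was_training)
    return total_loss / max(num_batches, 1)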

— 1andDone, Oct 27 '23

Hi @1andDone ,

That function is called with data_loader_test:

https://github.com/pytorch/vision/blob/3fb88b3ef1ee8107df74ca776cb57931fe3e9e1e/references/detection/train.py#L325

which corresponds to the "val" split of COCO (the naming might be a bit unfortunate):

https://github.com/pytorch/vision/blob/3fb88b3ef1ee8107df74ca776cb57931fe3e9e1e/references/detection/train.py#L44
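
If it helps, here is a rough sketch of how the evaluator returned by evaluate() could be used to track the validation mAP per epoch; coco_evaluator.coco_eval["bbox"].stats[0] is the COCO AP @ IoU=0.50:0.95 computed by pycocotools after summarize(). num_epochs, optimizer, lr_scheduler, data_loader, and data_loader_test are assumed to come from the surrounding training script:

map_history = []
for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=100)
    lr_scheduler.step()
    coco_evaluator = evaluate(model, data_loader_test, device=device)
    map_history.append(coco_evaluator.coco_eval["bbox"].stats[0])
    # a validation mAP that stalls or drops while the training loss keeps
    # decreasing is the usual sign of overfitting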

— NicolasHug, Oct 27 '23

So to check for overfitting, is it sufficient to check whether the mAP is decreasing, without looking at the validation loss at all?

— maltesilber, Jan 16 '24