How to tell if Faster RCNN Detection model is overfitting
I'm confused as to how I can tell if the Faster RCNN Detection model I'm training is overfitting or not given that the validation loss is not computed in the evaluate function seen here and below.
Any help would be greatly appreciated.
@torch.inference_mode()
def evaluate(model, data_loader, device):
    """Run ``model`` over ``data_loader`` and score its predictions with a ``CocoEvaluator``.

    The model is switched to eval mode and called with images only, and
    nothing in this function computes a loss: the only outputs are the
    evaluator's summary and the per-batch timing stats.

    Args:
        model: Detection model; called as ``model(images)`` and expected to
            return one dict of tensors per image.
        data_loader: Iterable of ``(images, targets)`` batches. Its
            ``dataset`` attribute is used to build the ground-truth API
            object, and each target must provide an ``"image_id"`` key.
        device: Device the input images are moved to before the forward pass.

    Returns:
        The ``CocoEvaluator`` after ``accumulate()`` and ``summarize()``
        have been called on it.
    """
    n_threads = torch.get_num_threads()
    # FIXME remove this and make paste_masks_in_image run on the GPU
    torch.set_num_threads(1)
    try:
        cpu_device = torch.device("cpu")
        model.eval()
        metric_logger = utils.MetricLogger(delimiter=" ")
        header = "Test:"

        coco = get_coco_api_from_dataset(data_loader.dataset)
        iou_types = _get_iou_types(model)
        coco_evaluator = CocoEvaluator(coco, iou_types)

        for images, targets in metric_logger.log_every(data_loader, 100, header):
            images = [img.to(device) for img in images]

            # Synchronize before starting the timer so the measurement does
            # not include previously queued GPU work.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            model_time = time.time()
            outputs = model(images)

            # Moving the outputs to the CPU is counted as part of model time.
            outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
            model_time = time.time() - model_time

            # Key each prediction by the image id of its target.
            res = {target["image_id"]: output for target, output in zip(targets, outputs)}
            evaluator_time = time.time()
            coco_evaluator.update(res)
            evaluator_time = time.time() - evaluator_time
            metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)

        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
        print("Averaged stats:", metric_logger)
        coco_evaluator.synchronize_between_processes()

        # accumulate predictions from all images
        coco_evaluator.accumulate()
        coco_evaluator.summarize()
        return coco_evaluator
    finally:
        # Restore the caller's thread count even if evaluation raised, so a
        # failed evaluation does not leave the process limited to one thread.
        torch.set_num_threads(n_threads)
Hi @1andDone ,
That function is called with data_loader_test:
https://github.com/pytorch/vision/blob/3fb88b3ef1ee8107df74ca776cb57931fe3e9e1e/references/detection/train.py#L325
which corresponds to the "val" part of COCO (the naming might be a bit unfortunate):
https://github.com/pytorch/vision/blob/3fb88b3ef1ee8107df74ca776cb57931fe3e9e1e/references/detection/train.py#L44
So, to check for overfitting, is it sufficient to check whether the mAP is decreasing, without looking at the validation loss at all?