
Stop training Callback

Open allankouidri opened this issue 2 years ago • 2 comments

💡 Your Question

Hi! I would like to create a custom Callback to stop the training process when a flag is set to True. To do so, I update `context.stop_training`, but the training doesn't stop.

`context.stop_training` is correctly updated inside my training script. However, the change is not reflected in `sg_trainer.py`. How do I update the `PhaseContext` during training?

Many thanks,

Please find a part of my code below:


# Custom Callback to stop training
class StopTrainingCallback(Callback):
    def __init__(self, stop_flag):
        super().__init__()
        self.stop_requested = stop_flag
        
    @multi_process_safe
    def on_train_batch_end(self, context: PhaseContext) -> None:
        if self.stop_requested[0]:
            print("The training should stop...")
            context.update_context(stop_training=True)


# My training process
class TrainYoloNas():
    def __init__(self, name, param):
        ...

        self.trainer = None
        self.stop_requested = [False]


        self.train_params = {
            "silent_mode": False,
            "average_best_models": True,
            "phase_callbacks": [StopTrainingCallback(self.stop_requested)],
        }

    def run(self):
        # Load model
        model = models.get(
            self.model_name, 
            num_classes=dataset_info['nc'], 
            pretrained_weights="coco"
        )

        # Trainer initialization
        self.trainer = Trainer(experiment_name=experiment_name, ckpt_root_dir=param.cfg["output_folder"])

        # Train
        self.trainer.train(
            model=model, 
            training_params=self.train_params, 
            train_loader=train_data, 
            valid_loader=val_data
        )

    # Stop training flag
    def stop(self):
        super().stop()
        print("Stopping requested...")
        self.stop_requested[0] = True

Versions

PyTorch version: 1.9.0+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 11 Home
GCC version: Could not collect
Clang version: Could not collect
CMake version: version 3.26.1
Libc version: N/A

Python version: 3.9.13 (tags/v3.9.13:6de2ca5, May 17 2022, 16:36:42) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.22621-SP0
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3060 Laptop GPU
Nvidia driver version: 531.61
cuDNN version: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin\cudnn_ops_train64_8.dll
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU: Architecture=9 CurrentClockSpeed=2496 DeviceID=CPU0 Family=205 L2CacheSize=1536 L2CacheSpeed= Manufacturer=GenuineIntel MaxClockSpeed=2496 Name=Intel(R) Core(TM) i5-10500H CPU @ 2.50GHz ProcessorType=3 Revision=

Versions of relevant libraries:
[pip3] flake8==6.0.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.0
[pip3] pytorch-quantization==2.1.2
[pip3] torch==1.9.0+cu111
[pip3] torchmetrics==0.8.0
[pip3] torchreid==1.4.0
[pip3] torchvision==0.10.0+cu111
[conda] Could not collect

allankouidri avatar Jun 08 '23 15:06 allankouidri

Hi @allankouidri. Your code looks OK. I think we have a bug in `Trainer` (the context inside the epoch is not the same as the context of the general loop). Try using `on_train_loader_end` instead of `on_train_batch_end`. This will not stop training immediately, but it will work at the end of the epoch.

A bug was opened to track this.
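To illustrate the suggested workaround without pulling in the library, here is a minimal, dependency-free sketch of the flag mechanics: the callback checks a shared mutable flag once per epoch (at `on_train_loader_end`) and raises `stop_training` on the context. `DummyContext` is a hypothetical stand-in for super-gradients' `PhaseContext`, and the loop below only simulates the trainer's epoch loop.

```python
class DummyContext:
    """Stand-in for PhaseContext: update_context sets attributes on self."""
    def __init__(self):
        self.stop_training = False

    def update_context(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)


class StopTrainingCallback:
    """Checks a shared mutable flag once per epoch instead of per batch."""
    def __init__(self, stop_flag):
        # stop_flag is a shared one-element list, e.g. [False], so an
        # external thread can flip it in place.
        self.stop_requested = stop_flag

    def on_train_loader_end(self, context):
        if self.stop_requested[0]:
            context.update_context(stop_training=True)


# Simulated epoch loop: an external stop request arrives during epoch 1,
# and the callback ends training at the end of that epoch.
stop_flag = [False]
callback = StopTrainingCallback(stop_flag)
context = DummyContext()

epochs_run = 0
for epoch in range(10):
    epochs_run += 1
    if epoch == 1:  # e.g. a UI thread calling stop()
        stop_flag[0] = True
    callback.on_train_loader_end(context)
    if context.stop_training:
        break

print(epochs_run)  # training stops after the epoch in which the flag was set
```

The key point is that the flag lives in a mutable container shared between the caller and the callback, so `stop()` can signal the training loop without any direct reference to the trainer.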

ofrimasad avatar Jun 14 '23 10:06 ofrimasad

Hi @ofrimasad, I confirm that using `on_train_loader_end` ends the training at the end of the epoch. Many thanks.

allankouidri avatar Jun 14 '23 13:06 allankouidri