Stop training Callback
💡 Your Question
Hi! I would like to create a custom Callback that stops the training process when a flag is set to True. To do so, I update context.stop_training, but the training doesn't stop.
context.stop_training is correctly updated inside my training script. However, the change is not visible in sg_trainer.py. How do I update the PhaseContext during training?
Many thanks,
Please find the relevant part of my code below:
```python
from super_gradients.training import Trainer, models
from super_gradients.training.utils.callbacks import Callback, PhaseContext
# Note: the import path of multi_process_safe varies across super-gradients versions
from super_gradients.common.environment.ddp_utils import multi_process_safe


# Custom Callback to stop training
class StopTrainingCallback(Callback):
    def __init__(self, stop_flag):
        super().__init__()
        # A mutable one-element list shared with the caller, so the flag
        # can be flipped from outside after the callback is constructed.
        self.stop_requested = stop_flag

    @multi_process_safe
    def on_train_batch_end(self, context: PhaseContext) -> None:
        if self.stop_requested[0]:
            print("The training should stop...")
            context.update_context(stop_training=True)


# My training process
class TrainYoloNas:
    def __init__(self, name, param):
        ...
        self.trainer = None
        self.stop_requested = [False]
        self.train_params = {
            "silent_mode": False,
            "average_best_models": True,
            "phase_callbacks": [StopTrainingCallback(self.stop_requested)],
        }

    def run(self):
        # Load model (dataset_info, experiment_name, param, train_data and
        # val_data come from setup code elided above)
        model = models.get(
            self.model_name,
            num_classes=dataset_info["nc"],
            pretrained_weights="coco",
        )
        # Trainer initialization
        self.trainer = Trainer(
            experiment_name=experiment_name,
            ckpt_root_dir=param.cfg["output_folder"],
        )
        # Train
        self.trainer.train(
            model=model,
            training_params=self.train_params,
            train_loader=train_data,
            valid_loader=val_data,
        )

    # Set the stop-training flag
    def stop(self):
        print("Stopping requested...")
        self.stop_requested[0] = True
```
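For context, stop() is presumably called while run() is blocked inside trainer.train(), e.g. from another thread. A hypothetical driver (the threading setup and constructor arguments are my assumptions, not part of the original code) could look like this:

```python
import threading

job = TrainYoloNas("my_experiment", param)  # hypothetical arguments
worker = threading.Thread(target=job.run)
worker.start()
# ... later, from the controlling thread:
job.stop()      # sets stop_requested[0] = True; the callback picks it up
worker.join()
```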
Versions
```
PyTorch version: 1.9.0+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 11 Home
GCC version: Could not collect
Clang version: Could not collect
CMake version: version 3.26.1
Libc version: N/A

Python version: 3.9.13 (tags/v3.9.13:6de2ca5, May 17 2022, 16:36:42) [MSC v.1929 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.22621-SP0
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3060 Laptop GPU
Nvidia driver version: 531.61
cuDNN version: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin\cudnn_ops_train64_8.dll
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU: Architecture=9, CurrentClockSpeed=2496, DeviceID=CPU0, Family=205, L2CacheSize=1536, L2CacheSpeed=, Manufacturer=GenuineIntel, MaxClockSpeed=2496, Name=Intel(R) Core(TM) i5-10500H CPU @ 2.50GHz, ProcessorType=3, Revision=

Versions of relevant libraries:
[pip3] flake8==6.0.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.0
[pip3] pytorch-quantization==2.1.2
[pip3] torch==1.9.0+cu111
[pip3] torchmetrics==0.8.0
[pip3] torchreid==1.4.0
[pip3] torchvision==0.10.0+cu111
[conda] Could not collect
```
Hi @allankouidri. Your code looks OK. I think we have a bug in the Trainer (the context inside the epoch is not the same as the context of the general loop). Try using on_train_loader_end instead of on_train_batch_end; this will not stop immediately, but it will work at the end of the epoch.
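A minimal sketch of that workaround, assuming the same shared stop-flag list as in the question (only the hook name changes):

```python
# Sketch only: same callback as in the question, but hooked on
# on_train_loader_end, which fires once per epoch on the outer-loop context.
class StopAtEpochEndCallback(Callback):
    def __init__(self, stop_flag):
        super().__init__()
        self.stop_requested = stop_flag

    @multi_process_safe
    def on_train_loader_end(self, context: PhaseContext) -> None:
        if self.stop_requested[0]:
            print("Stop requested; training will end after this epoch...")
            context.update_context(stop_training=True)
```

Because this hook fires when the train loader finishes rather than after every batch, the flag is only checked once per epoch, which is why the stop is not immediate.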
A bug was opened.
Hi @ofrimasad, I confirm that using on_train_loader_end ends the training at the end of the epoch. Many thanks.