
callbacks with PyTorchPySparkEstimator aren't working as expected

Open · timsetsfire opened this issue 2 years ago · 6 comments

I created my own callback to do some tracking with Weights and Biases.

import wandb  # imported at module level so the callback methods below can resolve it

class WandBCallback(Callback):
    def __init__(self, project, data_creator):
        self.run = None
        self.model = None
        self.params = None
        self.trainer = None
        self.project = project
        self.data_creator = data_creator

    def on_batch_begin(self, batch):
        pass

    def on_batch_end(self, batch):
        pass

    def on_epoch_begin(self, epoch):
        pass

    def on_epoch_end(self, epoch):
        # score the validation set and log this epoch's metrics to W&B
        out_metrics = self.trainer.validate(self.data_creator)
        self.run.log({"epoch": epoch, "metrics": out_metrics})

    def on_train_begin(self):
        self.run = wandb.init(project=self.project, group="distributed-training")

    def on_train_end(self):
        self.run.finish()

    def set_model(self, model):
        self.model = model

    def set_param(self, param):
        self.params = param

    def set_trainer(self, trainer):
        self.trainer = trainer

If I run this code

est = PyTorchPySparkEstimator(
    model_creator=model_creator,
    optimizer_creator=optimizer_creator,
    loss_creator=criterion_creator,
    workers_per_node=1,
    model_dir=model_directory)
res = est.fit(train_data_creator, batch_size=60000, epochs=2)
res

I do indeed see improvements from one epoch to the next:

Out[156]: [{'num_samples': 60000,
  'epoch': 1,
  'batch_count': 6,
  'train_loss': 1.851143717765808,
  'last_train_loss': 1.2832428216934204},
 {'num_samples': 60000,
  'epoch': 2,
  'batch_count': 6,
  'train_loss': 0.6781290769577026,
  'last_train_loss': 0.4633287191390991}]

whereas this code

est = PyTorchPySparkEstimator(
    model_creator=model_creator,
    optimizer_creator=optimizer_creator,
    loss_creator=criterion_creator,
    workers_per_node=5,
    model_dir=model_directory)
res = est.fit(train_data_creator, batch_size=10000, epochs=2,
              callbacks=[WandBCallback("bigdl-test-v3", test_data_creator)])
res

Results in

Out[157]: [{'num_samples': 60000,
  'epoch': 1,
  'batch_count': 6,
  'train_loss': 1.8442798852920532,
  'last_train_loss': 1.2671557664871216},
 {'num_samples': 60000,
  'epoch': 1,
  'batch_count': 6,
  'train_loss': 1.8442798852920532,
  'last_train_loss': 1.2671557664871216}]

In my callback, it appears that only self.trainer is actually being set, and invoking self.trainer.validate within on_epoch_end (or anywhere else, for that matter) does something curious behind the scenes that I can't figure out: it's almost as if it resets the model to its initial state.

Apologies if I'm doing something wrong, but I couldn't find any documentation laying out callback usage. Any guidance on how to set this up properly would be greatly appreciated.

timsetsfire avatar May 06 '22 15:05 timsetsfire

I tried creating the estimator via

est = Estimator.from_torch(
    model=model_creator,
    optimizer=optimizer_creator,
    loss=criterion_creator,
    backend="spark",
    model_dir="file:///tmp/bigdl-model2")
res = est.fit(train_data_creator, batch_size=10000, epochs=2,
              callbacks=[WandBCallback("bigdl-test-v3", test_data_creator)])

I experience the same "weird" behavior:

[{'num_samples': 60000,
  'epoch': 1,
  'batch_count': 6,
  'train_loss': 1.843607485294342,
  'last_train_loss': 1.2698042392730713},
 {'num_samples': 60000,
  'epoch': 1,
  'batch_count': 6,
  'train_loss': 1.843607485294342,
  'last_train_loss': 1.2698042392730713}]

timsetsfire avatar May 06 '22 16:05 timsetsfire

Hey, sorry for the late reply! I have reviewed your question and the code you provided.

  1. Reason. This happens because self.trainer.validate is invoked in the on_epoch_end function. In PyTorchPySparkEstimator, that means trainer.validate runs inside the training loop for each of the epochs=2 epochs against a model whose weights have not yet been updated, and the trainer is also reset on each pass, which is why the returned metrics repeat the first epoch's values. Getting the validate result inside a callback this way is not a good idea; see the sketch after this list.

  2. Solution. Sorry, I am not clear about what kind of results you want to get. Would you mind providing more details?
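
To make the failure mode concrete, here is a rough, hypothetical sketch. It is not BigDL's actual source; TrainingRunner, _setup, and the toy weight dict are illustrative stand-ins. It shows how a validate call that re-enters the runner and rebuilds state from the creator functions produces identical stats for every epoch:

# Hypothetical, simplified runner: illustrative only, not BigDL's source.
class TrainingRunner:
    def __init__(self, model_creator):
        self.model_creator = model_creator
        self.model = None

    def _setup(self):
        # Every entry point rebuilds the model from the creator function,
        # so a re-entrant call starts over from freshly initialized weights.
        self.model = self.model_creator()

    def _train_one_epoch(self):
        self.model["weight"] += 1.0  # stand-in for real optimizer steps

    def fit(self, epochs, callbacks=()):
        self._setup()
        stats = []
        for epoch in range(1, epochs + 1):
            self._train_one_epoch()
            stats.append({"epoch": epoch, "train_weight": self.model["weight"]})
            for cb in callbacks:
                # If on_epoch_end re-enters self.validate, _setup() runs
                # again and the weights just trained are discarded, so the
                # next epoch starts from scratch and reports identical stats.
                cb.on_epoch_end(epoch)
        return stats

    def validate(self, data_creator=None):
        self._setup()  # resets the model before evaluating
        return {"val_weight": self.model["weight"]}

class EpochEndValidate:
    def __init__(self, runner):
        self.runner = runner

    def on_epoch_end(self, epoch):
        print(epoch, self.runner.validate())  # always sees initial weights

runner = TrainingRunner(lambda: {"weight": 0.0})
print(runner.fit(epochs=2, callbacks=[EpochEndValidate(runner)]))
# Both epochs report train_weight 1.0, just like the repeated metrics above.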

sgwhat avatar May 07 '22 15:05 sgwhat

No apologies necessary, and I apologize for not being entirely clear myself. The purpose of my callback is to evaluate a validation set at the end of each epoch and relay the results to Weights and Biases for tracking.

Do you have a recommendation on how best to achieve the desired behavior?

timsetsfire avatar May 07 '22 15:05 timsetsfire

I am sorry that currently there might not be an easy way to achieve this. As a workaround, you could manually call fit and evaluate every epoch:

est = Estimator.from_torch(
    model=model_creator,
    optimizer=optimizer_creator,
    loss=criterion_creator,
    backend="spark",
    model_dir="file:///tmp/bigdl-model2")

epochs = 2
for i in range(epochs):
    # train a single epoch, then score the validation set
    train_stats = est.fit(train_data_creator, batch_size=10000, epochs=1)
    val_stats = est.evaluate(test_data_creator, batch_size=test_batch_size)

    # get the pytorch model for weights and biases
    model = est.get_model()
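
To close the loop on the original goal, the workaround above could report those stats to Weights and Biases roughly like this. This is a minimal sketch assuming the standard wandb client and the stats structures shown in the outputs earlier in this thread; adjust the logged keys to whatever fit and evaluate actually return in your setup.

import wandb

run = wandb.init(project="bigdl-test-v3", group="distributed-training")

epochs = 2
for i in range(epochs):
    train_stats = est.fit(train_data_creator, batch_size=10000, epochs=1)
    val_stats = est.evaluate(test_data_creator, batch_size=test_batch_size)

    # fit returns a list with one stats dict per epoch trained; epochs=1
    # here, so take the single entry and log it alongside the val metrics
    run.log({"epoch": i + 1, "train": train_stats[0], "val": val_stats})

run.finish()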

Note that we will add support for validating every n epochs during training in fit, but this feature may take a while to be fully supported. I will update the status in this issue.

shanyu-sys avatar May 09 '22 09:05 shanyu-sys

We have integrated a Weights and Biases callback, WandbLoggerCallback (code), which automatically logs metric results to Weights and Biases at the end of each epoch. Validation metrics will also be logged automatically if you pass validation_data in fit. You can choose whether to watch model gradients via watch_model.

  1. Run wandb login from the command line.
  2. In your Python script:

from bigdl.orca.learn.pytorch.callbacks.wandb import WandbLoggerCallback

wandb_callback = WandbLoggerCallback(project="cifar", watch_model=True)
estimator.fit(train_data_creator, batch_size=10000, epochs=2,
              validation_data=test_data_creator, callbacks=[wandb_callback])

shanyu-sys avatar Jul 04 '22 08:07 shanyu-sys

Wow! I can't wait to check this out. Thank you for your attention on this issue!!

timsetsfire avatar Jul 11 '22 13:07 timsetsfire