callbacks with PyTorchPySparkEstimator aren't working as expected
I created my own callback to do some tracking with Weights and Biases.
```python
import wandb

# Callback base class from BigDL Orca; the exact import path may vary by version.
from bigdl.orca.learn.pytorch.callbacks.base import Callback


class WandBCallback(Callback):
    def __init__(self, project, data_creator):
        self.run = None
        self.model = None
        self.params = None
        self.trainer = None
        self.project = project
        self.data_creator = data_creator

    def on_batch_begin(self, batch):
        pass

    def on_batch_end(self, batch):
        pass

    def on_epoch_begin(self, epoch):
        pass

    def on_epoch_end(self, epoch):
        # Run validation at the end of each epoch and log the metrics to W&B.
        out_metrics = self.trainer.validate(self.data_creator)
        self.run.log({"epoch": epoch, "metrics": out_metrics})

    def on_train_begin(self):
        self.run = wandb.init(project=self.project, group="distributed-training")

    def on_train_end(self):
        self.run.finish()

    def set_model(self, model):
        self.model = model

    def set_param(self, param):
        self.params = param

    def set_trainer(self, trainer):
        self.trainer = trainer
```
If I run this code

```python
est = PyTorchPySparkEstimator(
    model_creator=model_creator,
    optimizer_creator=optimizer_creator,
    loss_creator=criterion_creator,
    workers_per_node=1,
    model_dir=model_directory)
res = est.fit(train_data_creator, batch_size=60000, epochs=2)
res
```
I do indeed see improvements from one epoch to the next:

```
Out[156]: [{'num_samples': 60000,
  'epoch': 1,
  'batch_count': 6,
  'train_loss': 1.851143717765808,
  'last_train_loss': 1.2832428216934204},
 {'num_samples': 60000,
  'epoch': 2,
  'batch_count': 6,
  'train_loss': 0.6781290769577026,
  'last_train_loss': 0.4633287191390991}]
```
whereas this code

```python
est = PyTorchPySparkEstimator(
    model_creator=model_creator,
    optimizer_creator=optimizer_creator,
    loss_creator=criterion_creator,
    workers_per_node=5,
    model_dir=model_directory)
res = est.fit(train_data_creator, batch_size=10000, epochs=2,
              callbacks=[WandBCallback("bigdl-test-v3", test_data_creator)])
res
```
results in:

```
Out[157]: [{'num_samples': 60000,
  'epoch': 1,
  'batch_count': 6,
  'train_loss': 1.8442798852920532,
  'last_train_loss': 1.2671557664871216},
 {'num_samples': 60000,
  'epoch': 1,
  'batch_count': 6,
  'train_loss': 1.8442798852920532,
  'last_train_loss': 1.2671557664871216}]
```
In my callback, it appears that only `self.trainer` is being set, and that invoking `self.trainer.validate` within `on_epoch_end` (or anywhere, for that matter) does something curious behind the scenes that I can't figure out; it's almost as if it resets the state of the model to its initialization. Apologies if I'm doing something wrong, but I couldn't find any documentation laying out callback usage. Any guidance on how to set this up properly would be greatly appreciated.
I tried creating the estimator via

```python
est = Estimator.from_torch(
    model=model_creator,
    optimizer=optimizer_creator,
    loss=criterion_creator,
    backend="spark",
    model_dir="file:///tmp/bigdl-model2")
res = est.fit(train_data_creator, batch_size=10000, epochs=2,
              callbacks=[WandBCallback("bigdl-test-v3", test_data_creator)])
```

and I experience the same "weird" behavior:
```
  'epoch': 1,
  'batch_count': 6,
  'train_loss': 1.843607485294342,
  'last_train_loss': 1.2698042392730713},
 {'num_samples': 60000,
  'epoch': 1,
  'batch_count': 6,
  'train_loss': 1.843607485294342,
  'last_train_loss': 1.2698042392730713}]
```
Hey, sorry for the late reply! I have reviewed your question and the code you provided.

- Reason: This happens because `self.trainer.validate` is invoked in the `on_epoch_end` function. In `PyTorchPySparkEstimator`, this means the `trainer.validate` call made from `callback.on_epoch_end` runs inside the training loop for each of the `epochs=2` iterations against a model whose updates have not yet been applied, and the trainer is also reset in each loop, so fetching the `validate` result from inside callbacks is not a good idea.
- Solution: Sorry, I am not clear about what kind of results you want to get; would you mind providing more details?
No apologies necessary. I apologize that I was not entirely clear. The purpose of my callback is to evaluate a validation set at the end of each epoch and relay the results to Weights and Biases for tracking. Do you have a recommendation on how best to achieve the desired behavior?
I am sorry that currently there might not be an easy way to achieve this. As a workaround, you could manually call `fit` and `evaluate` every epoch:
```python
est = Estimator.from_torch(
    model=model_creator,
    optimizer=optimizer_creator,
    loss=criterion_creator,
    backend="spark",
    model_dir="file:///tmp/bigdl-model2")

epochs = 2
for i in range(epochs):
    train_stats = est.fit(train_data_creator, batch_size=10000, epochs=1)
    val_stats = est.evaluate(test_data_creator, batch_size=test_batch_size)
    # get the pytorch model for weights and biases
    model = est.get_model()
```
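For reference, here is a minimal sketch (continuing from the snippet above) of how that per-epoch loop could be combined with the wandb client API to push the statistics after every epoch. The `wandb.init`/`run.log`/`run.finish` calls are standard wandb usage; the structure of `train_stats` (a list with one dict per epoch) and `val_stats` (a dict of metrics) is assumed from the outputs shown earlier, so treat the logged keys as placeholders.

```python
import wandb

# Sketch only: drive training one epoch at a time on the driver and push the
# returned statistics to Weights and Biases after each epoch.
run = wandb.init(project="bigdl-test-v3", group="distributed-training")

epochs = 2
for i in range(epochs):
    train_stats = est.fit(train_data_creator, batch_size=10000, epochs=1)
    val_stats = est.evaluate(test_data_creator, batch_size=test_batch_size)
    # est.fit with epochs=1 returns stats for a single epoch; est.evaluate
    # returns a dict of metrics (shapes assumed from the outputs above).
    run.log({"epoch": i + 1, "train": train_stats[0], "validation": val_stats})
    # Optionally pull the PyTorch model if you also want to log weights or gradients.
    model = est.get_model()

run.finish()
```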
Note that we will add support for validating every n epochs during training in `fit`, but this feature may take a while to be fully supported. I will update the status in this issue.
We have integrated a callback for Weights and Biases as `WandbLoggerCallback` (code), which automatically uses Weights and Biases to log metric results at the end of each epoch. Validation metric results will also be logged automatically if the user passes `validation_data` to `fit`. You can choose whether to watch model gradients via `watch_model`.
- Run `wandb login` from the command line.
- In the Python script:

```python
from bigdl.orca.learn.pytorch.callbacks.wandb import WandbLoggerCallback

wandb_callback = WandbLoggerCallback(project="cifar", watch_model=True)
estimator.fit(train_data_creator, batch_size=10000, epochs=2,
              validation_data=test_data_creator, callbacks=[wandb_callback])
```
Wow! I can't wait to check this out. Thank you for your attention on this issue!!