pytorch_ema
is model.eval() necessary within or outside context manager
Thanks for this awesome lib! I had a little question: do we need to call `model.eval()` inside the `with` context manager? So, which of these is legit?
```python
with ema.average_parameters():
    model.eval()
    logits = model(x_val)
    loss = F.cross_entropy(logits, y_val)
    print(loss.item())
```
Or is this the thing to do:
```python
model.eval()
with ema.average_parameters():
    logits = model(x_val)
    loss = F.cross_entropy(logits, y_val)
    print(loss.item())
```
Finally, does this setup work in distributed mode (are there any examples of this)?
Thank you.
`model.eval()` doesn't do anything except set a boolean flag on `model`, so whether you call it inside or outside the context manager shouldn't matter. (See `def train(...)` in https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module.eval.)
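To illustrate (a generic PyTorch sketch, not specific to this library): `eval()` and `train()` only toggle the module's `training` attribute, recursively on its submodules.

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Dropout(0.5), nn.Linear(8, 2))
model.eval()
print(model.training, model[1].training)   # False False -- just a flag, set recursively
model.train()
print(model.training, model[1].training)   # True True
```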
Re distributed: if you train a distributed model with synced gradients, it should be fine as long as you call `ema.update()` after the sync. As long as the model is synced before you feed it to EMA, the EMAs on all ranks will hold the same averaged state.
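For concreteness, here is a rough sketch of that ordering with `DistributedDataParallel`. The model, optimizer, data loader, and `x_val` are placeholders, and I'm assuming this library's `ExponentialMovingAverage` class with its `update()` / `average_parameters()` API:

```python
# Rough sketch: EMA + DistributedDataParallel. Model, data, and process-group
# setup are placeholders; ExponentialMovingAverage is assumed to be this
# library's class with update() and average_parameters().
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch_ema import ExponentialMovingAverage

dist.init_process_group("nccl")
rank = dist.get_rank()

model = nn.Linear(32, 10).cuda(rank)
ddp_model = DDP(model, device_ids=[rank])
optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
ema = ExponentialMovingAverage(ddp_model.parameters(), decay=0.995)

for x, y in loader:  # placeholder DataLoader with a DistributedSampler
    optimizer.zero_grad()
    loss = F.cross_entropy(ddp_model(x.cuda(rank)), y.cuda(rank))
    loss.backward()   # DDP all-reduces gradients here, so...
    optimizer.step()  # ...parameters are identical on every rank after this
    ema.update()      # update EMA *after* the sync -> identical EMA on all ranks

# Evaluate with the averaged weights (the same on every rank)
ddp_model.eval()
with ema.average_parameters(), torch.no_grad():
    val_logits = ddp_model(x_val.cuda(rank))  # x_val: placeholder validation batch
```

The key point is just that `ema.update()` runs after `optimizer.step()`, i.e. after DDP has synced gradients and every rank holds the same parameters.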