Possible speed improvement
Description
I recently noticed a possible speed improvement when porting the segmentation example to the code-generator. According to the docs, examples, helper functions, and tests, we are calling `model.train()` or `model.eval()` inside the process function given to the `Engine`.
As per this line, `model.train()` or `model.eval()` then gets called on every iteration, which I think is not necessary:
https://github.com/pytorch/ignite/blob/0d21a0bfe0bb03980055f9748b827cf09a6faee3/ignite/engine/engine.py#L853
To my knowledge, `.train()` or `.eval()` is only needed when changing context (e.g. from training to evaluation). In the PyTorch examples, `model.train()` or `model.eval()` is called only once when the context changes (training -> evaluation, or vice versa).
Possible solution
I think we can address this in the ongoing high-level API development, or in the refactor of the Engine design.
Workaround
For now, we can work around this by calling `model.train()` before every training epoch and `model.eval()` before every evaluation epoch.
Thanks!
Oh, I think it's a very nice remark, if `model.train()` is actually taking time. Assuming it does, you're right that avoiding this in the higher-level API is important, and I think we should find a way to refactor the engine.
Thanks!
@ydcjeff Thank you for pointing this out. Did you measure the overhead of calling `.train()` at each iteration? I would expect no overhead from PyTorch...
Thanks @ydcjeff for the question!
The overhead of calling `module.train()` is just setting the `module.training` flag on the module and its children:
https://github.com/pytorch/pytorch/blob/28840b9a447e820bf2e66f6b6cdc880c12d077b7/torch/nn/modules/module.py#L1637-L1642
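For illustration, here is a pure-Python sketch of what that linked snippet does (a simplified stand-in for `nn.Module`, not the real implementation): `.train()` and `.eval()` only walk the module tree and flip a boolean flag, so the per-call cost is proportional to the number of submodules.

```python
class TinyModule:
    """Simplified stand-in for torch.nn.Module, keeping only the
    training-flag logic from the linked snippet."""

    def __init__(self, *children):
        self.training = True
        self._children = list(children)

    def children(self):
        return iter(self._children)

    def train(self, mode=True):
        # This is essentially all .train()/.eval() do: set a flag, recurse.
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

    def eval(self):
        return self.train(False)

# A small module tree: root with two children, one of which has a child.
leaf = TinyModule()
model = TinyModule(TinyModule(leaf), TinyModule())
```

So the cost is a handful of attribute assignments per call, which is likely negligible next to a forward/backward pass, but it is still pure waste when repeated every iteration.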
Like @sdesrozis, I do not expect a large overhead from that either. I agree that in a higher abstraction we could try to deal with it. However, I think it could be a bit difficult to handle in general, as a user could potentially attach a handler on iterations that modifies the flag and does not change it back before continuing the training.
We could either explicitly state in the high-level API that we have already called `.train()` or `.eval()`, or let users call them themselves.