tnt
How to check `is_last_batch` in torchtnt==0.1.0?
🐛 Describe the bug
In the `train_step` method of a Callback class:
    def train_step(self, state: State, data: TrainBatch) -> None:
        global_step = state.train_state.progress.num_steps_completed
        is_last_batch = state.train_state.is_last_batch  # error
Error:

    AttributeError: 'PhaseState' object has no attribute 'is_last_batch'
    train_state._step_output = train_unit.train_step(state, step_input)
Versions
Hi @daniellepintz
Why was `is_last_batch` removed in #367?
This won't pass the test defined in https://github.com/pytorch/tnt/blob/9b3b7b1a3c0cfa8354bd459fe84a46a03b2754f5/tests/framework/test_auto_unit.py#L901
What is the correct way to check `is_last_batch` now? Thank you in advance!
The env:
torchtnt==0.1.0
torcheval==0.0.6
torchsnapshot==0.1.0
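
For reference, one library-independent way to know whether the current batch is the last one is to look one batch ahead in the data iterator. The sketch below is only an illustration of that idea and does not use any torchtnt API: `with_last_flag` is a hypothetical helper name, and the list of strings stands in for a real dataloader. Whether torchtnt 0.1.0 exposes an equivalent flag somewhere else is not something this sketch assumes.

```python
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")


def with_last_flag(batches: Iterable[T]) -> Iterator[Tuple[T, bool]]:
    """Yield (batch, is_last_batch) pairs by looking one batch ahead.

    Hypothetical helper, not part of torchtnt.
    """
    it = iter(batches)
    try:
        current = next(it)
    except StopIteration:
        return  # empty input: nothing to yield
    for upcoming in it:
        # A later batch exists, so `current` is not the last one.
        yield current, False
        current = upcoming
    # The iterator is exhausted, so `current` is the final batch.
    yield current, True


# Stand-in for a dataloader; any iterable of batches works.
for batch, is_last_batch in with_last_flag(["a", "b", "c"]):
    print(batch, is_last_batch)
```

The look-ahead costs one extra batch held in memory, but it works for iterables that have no `len()`. When the dataloader is sized, comparing the step index against `len(dataloader) - 1` would be a simpler alternative.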