
How to check `is_last_batch` in torchtnt==0.1.0?

Open yiminglin-ai opened this issue 1 year ago • 2 comments

🐛 Describe the bug

In the `train_step` method of my `TrainUnit` subclass:

```python
def train_step(self, state: State, data: TrainBatch) -> None:
    global_step = state.train_state.progress.num_steps_completed
    # Raises AttributeError in torchtnt 0.1.0: PhaseState has no is_last_batch
    is_last_batch = state.train_state.is_last_batch
```

Error:

```
AttributeError: 'PhaseState' object has no attribute 'is_last_batch'
    train_state._step_output = train_unit.train_step(state, step_input)
```

Versions

Hi @daniellepintz, why was `is_last_batch` removed in #367? Without it, the test defined in https://github.com/pytorch/tnt/blob/9b3b7b1a3c0cfa8354bd459fe84a46a03b2754f5/tests/framework/test_auto_unit.py#L901 won't pass. What is the correct way to check `is_last_batch` now? Thank you in advance!
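Because `PhaseState` in torchtnt 0.1.0 has no `is_last_batch` attribute, one possible workaround is to compute the flag on the data side by prefetching one batch ahead. This is a minimal sketch under stated assumptions, not a torchtnt API: `flag_last` and `LastBatchFlagger` are hypothetical helpers, and the sketch assumes only that the training dataloader is a plain Python iterable that gets re-iterated once per epoch (no `len()` is required).

```python
from typing import Generic, Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")


def flag_last(batches: Iterable[T]) -> Iterator[Tuple[T, bool]]:
    """Yield (batch, is_last_batch) pairs by looking one batch ahead.

    Hypothetical helper, not part of torchtnt.
    """
    it = iter(batches)
    try:
        current = next(it)
    except StopIteration:
        return  # empty iterable: nothing to yield
    for nxt in it:
        # A next batch exists, so `current` is not the last one.
        yield current, False
        current = nxt
    # The loop ran out, so `current` is the final batch of this pass.
    yield current, True


class LastBatchFlagger(Generic[T]):
    """Hypothetical re-iterable wrapper: each iter() starts a fresh pass,
    so it can stand in for a dataloader across multiple epochs."""

    def __init__(self, batches: Iterable[T]) -> None:
        self._batches = batches

    def __iter__(self) -> Iterator[Tuple[T, bool]]:
        return flag_last(self._batches)
```

With the training dataloader wrapped as `LastBatchFlagger(train_dataloader)`, each `data` reaching `train_step` would be a `(batch, is_last_batch)` pair, unpacked with `batch, is_last_batch = data`. The trade-off is that the next batch is fetched before the current step runs; if the unit moves batches to a device automatically, check how it treats the extra boolean in the tuple before relying on this.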

The env:

```
torchtnt==0.1.0
torcheval==0.0.6
torchsnapshot==0.1.0
```

yiminglin-ai · Jun 14 '23 13:06