ijepa
Why is there no **unscale_** call when you use AMP?
Your code is:

```python
if use_bfloat16:
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```
I think it should be:

```python
if use_bfloat16:
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    scaler.step(optimizer)
    scaler.update()
```
Am I right?
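Some background for the question, based on the PyTorch AMP documentation rather than on the ijepa code itself. `scaler.step(optimizer)` unscales the gradients on its own unless `scaler.unscale_(optimizer)` has already been called for that optimizer in the current iteration. An explicit `unscale_` therefore only matters when something reads or modifies the gradients between `backward()` and `step()`, and gradient clipping is the usual case. The pure-Python toy below is a hypothetical, simplified model of that behaviour. The names `ToyScaler`, `clip_by_norm_` and `run` are made up for illustration, and the code is not `GradScaler`'s real implementation. It shows that adding `unscale_` by itself leaves the update unchanged, while clipping the still-scaled gradients shrinks the update by the scale factor.

```python
import math


class ToyScaler:
    """Hypothetical, simplified model of loss scaling; NOT torch's GradScaler."""

    def __init__(self, scale=2.0 ** 16):
        self.scale = scale
        self._unscaled = False  # whether unscale_ already ran this iteration

    def scaled_backward(self, true_grads):
        # Stands in for scaler.scale(loss).backward(): grads carry the scale factor.
        return [g * self.scale for g in true_grads]

    def unscale_(self, grads):
        # Divide the stored gradients by the scale, in place.
        for i, g in enumerate(grads):
            grads[i] = g / self.scale
        self._unscaled = True

    def step(self, params, grads, lr):
        # Mirrors the documented GradScaler.step behaviour: unscale first if the
        # caller has not already done so, then apply a plain SGD update.
        if not self._unscaled:
            self.unscale_(grads)
        self._unscaled = False
        return [p - lr * g for p, g in zip(params, grads)]


def clip_by_norm_(grads, max_norm):
    # Hypothetical helper: rescale grads in place so their L2 norm is <= max_norm.
    norm = math.sqrt(sum(g * g for g in grads))
    if norm > max_norm:
        for i, g in enumerate(grads):
            grads[i] = g * (max_norm / norm)


def run(mode, params=(1.0, 1.0), true_grads=(3.0, 4.0), lr=0.1, max_norm=1.0):
    scaler = ToyScaler()
    grads = scaler.scaled_backward(list(true_grads))
    if mode == "unscale":
        scaler.unscale_(grads)           # the extra call suggested above
    elif mode == "clip_after_unscale":
        scaler.unscale_(grads)           # needed: clip the true gradients
        clip_by_norm_(grads, max_norm)
    elif mode == "clip_without_unscale":
        clip_by_norm_(grads, max_norm)   # bug: clips the *scaled* gradients
    return scaler.step(list(params), grads, lr)


for mode in ("plain", "unscale", "clip_after_unscale", "clip_without_unscale"):
    print(mode, run(mode))
```

In real PyTorch code, the documented clipping pattern is `scaler.scale(loss).backward()`, then `scaler.unscale_(optimizer)`, then `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)`, then `scaler.step(optimizer)` and `scaler.update()`. Without a step that touches the gradients in between, the extra `unscale_` call is redundant.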