aaaaammmmm

4 comments by aaaaammmmm

Reproduced the issue. In OpenRLHF 0.5.7, after upgrading DeepSpeed from 0.15.0 to 0.16.0, my PPO demo runs out of memory (OOM):

```
File "/root/miniconda3/lib/python3.10/site-packages/openrlhf/utils/deepspeed/deepspeed.py", line 129, in backward
    model.backward(loss)
File "/root/miniconda3/lib/python3.10/site-packages/deepspeed/utils/nvtx.py", line 18,...
```

This commit introduced the issue: https://github.com/deepspeedai/DeepSpeed/commit/cd20a3bbc7713908d7fb5fd7af4a91d52f126370

![Image](https://github.com/user-attachments/assets/f5d06819-40d0-4f31-bcc4-c5b4b745a2db)

Open-Reasoner-Zero identified the same issue: https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero/issues/13

I found that this issue occurs because `ds_grads_remaining` does not correctly count forward passes when gradient checkpointing is enabled.

![Image](https://github.com/user-attachments/assets/99b69bf5-8c16-43c0-9e1b-0c21a36d2584)

Below is the log with gradient checkpointing enabled:

![Image](https://github.com/user-attachments/assets/1c159dc4-5772-43be-92f5-a78c2eeefdf2)
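To illustrate that reasoning, here is a toy Python model. It is not DeepSpeed's actual implementation. It assumes a per-module counter (called `grads_remaining` here, loosely analogous to `ds_grads_remaining`) that is incremented on every forward call and decremented on every backward call, with the module's gathered parameters released only when the counter returns to zero. It further assumes that the forward recomputation done by gradient checkpointing is counted as an extra forward while only one backward reaches the module. `ToySubmodule` and `train_step` are hypothetical names.

```python
class ToySubmodule:
    """Toy model of counter-based parameter release (hypothetical, not DeepSpeed code)."""

    def __init__(self) -> None:
        self.grads_remaining = 0      # hypothetical counter, loosely analogous to ds_grads_remaining
        self.params_gathered = False  # True while the full parameters occupy GPU memory

    def forward(self) -> None:
        # Each forward call gathers the parameters and expects one matching backward call.
        self.params_gathered = True
        self.grads_remaining += 1

    def backward(self) -> None:
        # Each backward call settles one expected gradient.
        self.grads_remaining -= 1
        if self.grads_remaining == 0:
            # All expected gradients arrived: release (re-partition) the parameters.
            self.params_gathered = False


def train_step(gradient_checkpointing: bool) -> ToySubmodule:
    module = ToySubmodule()
    module.forward()  # original forward pass
    if gradient_checkpointing:
        # Assumption: recomputing the checkpointed forward during backward
        # increments the counter a second time.
        module.forward()
    module.backward()  # only one backward pass actually reaches the module
    return module


if __name__ == "__main__":
    for ckpt in (False, True):
        m = train_step(ckpt)
        print(f"checkpointing={ckpt}: grads_remaining={m.grads_remaining}, "
              f"params_gathered={m.params_gathered}")
```

Under these assumptions, the run without checkpointing ends with the counter at 0 and the parameters released, while the checkpointed run leaves the counter at 1 and the parameters still gathered. Repeated across every submodule, that would be consistent with the OOM reported above.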

![Image](https://github.com/user-attachments/assets/1b6dd5cf-2bd8-4a0e-bddc-aa4efa8d871c)

Same problem. I guess you could add `model_dtype` to the config to solve it.