trlx
Faster & memory-efficient logprobs calculation
The current `logprobs_of_labels` computes logprobs using a `log_softmax` followed by a `gather`. When the input logits are not contiguous (e.g. the slice `logits[:, :-1, :]`), `log_softmax` first makes a copy of the logits, which is very large: `batch_size * seq_len * vocab_size` can be 32 * 2048 * 64000 * 2 bytes ≈ 8.4 GB for typical settings. This PR feeds the contiguous logits directly into `log_softmax`, which reduces peak CUDA memory and removes the redundant copy.
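For reference, here is a minimal sketch of the `log_softmax`-plus-`gather` pattern and of the call-site change; this is an illustrative reimplementation under my own assumptions, not the exact trlx code:

```python
import torch
import torch.nn.functional as F

def logprobs_of_labels_sketch(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Illustrative only: log-softmax over the vocab, then gather the label ids.
    # If `logits` is a non-contiguous slice (e.g. logits[:, :-1, :]),
    # log_softmax materializes a contiguous copy of it first.
    logprobs = F.log_softmax(logits, dim=-1)
    # Keep only the positions that predict a label
    # (handles logits that are one step longer than labels).
    logprobs = logprobs[:, : labels.shape[1], :]
    return torch.gather(logprobs, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

# Call-site pattern before vs. after this PR (hypothetical variable names):
# before: logprobs_of_labels(logits[:, :-1, :], input_ids[:, 1:])   # slice -> extra ~8 GB copy
# after:  logprobs_of_labels(logits, input_ids[:, 1:])              # contiguous -> no copy
```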
Test script:
```python
import torch
from torch.utils.benchmark import Timer
from trlx.utils.modeling import logprobs_of_labels


def perf():
    batch_size, seq_len, vocab_size = 32, 2048, 64000
    logits = torch.randn((batch_size, seq_len, vocab_size), dtype=torch.half, device='cuda')
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long, device='cuda')

    # correctness: both call styles must produce the same values
    assert torch.allclose(logprobs_of_labels(logits[:, :-1, :], input_ids[:, 1:]),
                          logprobs_of_labels(logits, input_ids[:, 1:]))

    # peak memory test
    torch.cuda.empty_cache()
    logprobs_of_labels(logits[:, :-1, :], input_ids[:, 1:])
    print(f'original allocated: {torch.cuda.memory_allocated() / 1e9:.3f} GB, reserved: {torch.cuda.memory_reserved() / 1e9:.3f} GB')
    torch.cuda.empty_cache()
    logprobs_of_labels(logits, input_ids[:, 1:])
    print(f'optimized allocated: {torch.cuda.memory_allocated() / 1e9:.3f} GB, reserved: {torch.cuda.memory_reserved() / 1e9:.3f} GB')

    # speed test
    timer = Timer(stmt="logprobs_of_labels(logits[:, :-1, :], input_ids[:, 1:])", globals={**globals(), **locals()})
    elapsed_org = timer.timeit(100).mean
    print(f'original costs: {elapsed_org:.4f} s')
    timer = Timer(stmt="logprobs_of_labels(logits, input_ids[:, 1:])", globals={**globals(), **locals()})
    elapsed_opt = timer.timeit(100).mean
    print(f'optimized costs: {elapsed_opt:.4f} s')


perf()
```
Tested on a Tesla V100, the method in this PR is both faster (1.6x speedup) and more memory-efficient:
```
original allocated: 8.389 GB, reserved: 25.164 GB
optimized allocated: 8.389 GB, reserved: 16.779 GB
original costs: 0.0700 s
optimized costs: 0.0435 s
```
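As a rough sanity check on these numbers (my own arithmetic, not part of the PR): the fp16 logits tensor alone is about 8.4 GB, which matches both the allocated figure and the drop in reserved memory once the extra `log_softmax` copy is avoided.

```python
# Back-of-the-envelope: size of the fp16 logits tensor used in the benchmark.
batch_size, seq_len, vocab_size = 32, 2048, 64000
bytes_per_element = 2  # torch.half
size_gb = batch_size * seq_len * vocab_size * bytes_per_element / 1e9
print(f"{size_gb:.3f} GB")  # ~8.389 GB, i.e. the size of the copy the PR avoids
```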
Codecov Report
Attention: 6 lines in your changes are missing coverage. Please review.
Comparison is base (91a0f43) 43.58% compared to head (730d900) 43.58%. Report is 1 commit behind head on main.
:exclamation: Current head 730d900 differs from pull request most recent head aa1031a. Consider uploading reports for the commit aa1031a to get more accurate results.
| Files | Patch % | Lines |
|---|---|---|
| trlx/models/modeling_nemo_ppo.py | 0.00% | 3 Missing :warning: |
| trlx/trainer/accelerate_ppo_trainer.py | 57.14% | 3 Missing :warning: |
Additional details and impacted files
```
@@           Coverage Diff           @@
##             main     #583   +/-   ##
=======================================
  Coverage   43.58%   43.58%
=======================================
  Files          33       33
  Lines        4974     4974
=======================================
  Hits         2168     2168
  Misses       2806     2806
```