Add prim operators to query/update CUDA default RNG state
Background:
Currently, the mask computed for dropout in the forward trace is passed to the backward trace, which consumes too much memory. Instead, we want Thunder to save the seed and offset used for random number generation so that the mask can be recomputed in the backward pass.
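For intuition, here is a minimal plain-PyTorch sketch of that idea (not Thunder's actual implementation; `forward`/`backward` are illustrative helpers): the forward pass saves only the seed, and the backward pass regenerates the identical mask from it.

```python
import torch

def forward(a, seed, p=0.5):
    # Draw the dropout mask from an explicitly seeded generator, so the
    # mask itself never needs to be saved for backward; only `seed` does.
    gen = torch.Generator(device=a.device).manual_seed(seed)
    mask = torch.rand(a.shape, device=a.device, generator=gen) > p
    return a * mask / (1 - p)

def backward(grad, seed, p=0.5):
    # Re-seeding reproduces exactly the same mask as in forward.
    gen = torch.Generator(device=grad.device).manual_seed(seed)
    mask = torch.rand(grad.shape, device=grad.device, generator=gen) > p
    return grad * mask / (1 - p)
```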
Future steps:
- [x] Add the prim operators to query/update the default CUDA RNG state (this PR)
- [ ] Add a transformation to replace `uniform` with `uniform_philox` (PR)
- [ ] Make sure rematerialization works properly for Dropout (PR)
- [ ] Fix #231, so that the seed/offset passed to the backward trace are actual runtime values
This PR is the first step in fixing #114.
- It adds the prim operators needed to query/update the default CUDA RNG state in the same way nvFuser does.
- An example of the expected trace after replacing `uniform` with `uniform_philox`:
```python
# def func(a):
#     b = ltorch.uniform_like(a, device=a.device, dtype=a.dtype)
#     return b * a

# Constructed by Delete Last Used (took 0 milliseconds)
import thunder.core.devices as devices
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(a):
  # a: "cuda:0 f32[2, 2]"
  t2 = get_rng_state_prim_impl(None, devices.Device("cuda:0"))  # t2: "cpu ui8[16]"
  (i3, i4) = unpack_rng_state_prim_impl(t2)
  del t2
  [t1] = nvFusion0(a, i3, i4)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i3, offset=i4)  # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.mul(t0, a)  # t1: "cuda:0 f32[2, 2]"
  t6 = update_rng_state_prim_impl(i3, i4)  # t6: "cpu ui8[16]"
  del i3, i4
  set_rng_state_prim_impl(t6, devices.Device("cuda:0"))
  del t6
  return {'output': t1, 'flat_args': [a], 'flat_output': (t1,)}, ((), (0, 0))
```
```python
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  _, C1, = saved_for_backward
  clear_collection(saved_for_backward)
  del saved_for_backward
  t2, = cotangents
  clear_collection(cotangents)
  del cotangents
  i3, i4, = C1
  clear_collection(C1)
  del C1
  [t12] = nvFusion0(i3, i4, t2)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i3, offset=i4)  # t0: "cuda:0 f32[2, 2]"
    # t12 = prims.mul(t0, t2)  # t12: "cuda:0 f32[2, 2]"
  del i3, i4, t2
  return (t12,)
```
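The `t2: "cpu ui8[16]"` tensor in the forward trace above is the default CUDA generator state returned by `torch.cuda.get_rng_state()`. As a rough sketch of what unpacking could look like, assuming (for illustration only) that the 16 bytes hold the 64-bit seed followed by the 64-bit Philox offset:

```python
import torch

def unpack_rng_state(state):
    # Assumption for illustration: the 16-byte uint8 state is the 64-bit
    # seed followed by the 64-bit Philox offset.
    seed, offset = state.view(torch.int64).tolist()
    return seed, offset

seed, offset = unpack_rng_state(torch.cuda.get_rng_state())
```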
cc @apaz-cli
I would prefer that we functionalize the RNG state handling within thunder and I wonder if this could be achieved with moderate effort (so the problem is similar to #145 but the solution likely much easier).
> I would prefer that we functionalize the RNG state handling within thunder and I wonder if this could be achieved with moderate effort (so the problem is similar to #145 but the solution likely much easier).
Do you mean the state handling (this PR manually passes the state through the dataflow to keep the order of RNG ops)? I didn't know about that; do we have a general way to handle this?
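For illustration, "functionalized" RNG handling would mean that every random op consumes the state as an explicit input and produces the advanced state as an explicit output, so ordinary dataflow already enforces the ordering. A hypothetical sketch (`functional_uniform` and its offset handling are invented here for exposition, not the proposed API):

```python
import torch

def functional_uniform(shape, seed, offset):
    # Hypothetical stand-in for a philox-style op: the state (seed, offset)
    # flows in and the advanced offset flows out; no hidden state is mutated.
    gen = torch.Generator().manual_seed(seed + offset)  # illustrative only
    out = torch.rand(shape, generator=gen)
    return out, offset + 1  # "+ 1" stands in for the real offset increment

def func(shape, seed, offset):
    # The data dependency on `offset` fixes the order of the two RNG ops.
    out1, offset = functional_uniform(shape, seed, offset)
    out2, offset = functional_uniform(shape, seed, offset)
    return out1, out2, offset
```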
triage review —
There's some interesting stuff happening here! Before moving forward, can we do a design review with the broader team? We think lots of people would be interested.
@IvanYashchuk can help set up a design review
Sure, let me know your availability
Hi @t-vi @mruberry, I think it's ready to merge.
Mike asked for the design review meeting in https://github.com/Lightning-AI/lightning-thunder/pull/264#issuecomment-2083466316. We shouldn't merge this PR until more people are happy with the proposed features.
Here's the design doc that Yan wrote (NVIDIA-only link).
This PR is about adding special primitives that query and advance PyTorch's hidden RNG state, which would allow Thunder to generate code like in `test_uniform_philox_with_rng_state_prims`:
```python
def func(shape, dtype, device):
    seed0, offset0 = prims.get_and_update_rng_state(None, None, device=device)
    out1 = ltorch.uniform_philox(shape, device=device, dtype=dtype, seed=seed0, offset=offset0)
    seed1, offset1 = prims.get_and_update_rng_state(seed0, offset0, device)
    out2 = ltorch.uniform_philox(shape, device=device, dtype=dtype, seed=seed1, offset=offset1)
    return out1, out2
```
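As a quick plain-PyTorch illustration of the property these primitives build on: restoring the same RNG state reproduces the same draw, so random tensors can be regenerated instead of saved.

```python
import torch

state = torch.cuda.get_rng_state()   # capture the default CUDA RNG state
a = torch.rand(2, 2, device="cuda")
torch.cuda.set_rng_state(state)      # rewind to the captured state
b = torch.rand(2, 2, device="cuda")
assert torch.equal(a, b)             # identical state, identical draw
```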
This PR doesn't implement capturing `torch.cuda.get_rng_state` and `torch.cuda.set_rng_state`, so these operations are meant to be introduced into the trace by a Thunder trace transform.
Hi @mruberry @jjsjann123 @IvanYashchuk, I modified it according to our design review discussion. Could you take a look?
Hi @IvanYashchuk @t-vi, I think it's ready to merge. Setting the NumberProxy to None works fine in this PR, but it makes the corresponding input to the nvFusion in the backward trace None. That only affects the uniform transform PR, which needs the NumberProxy ready for its test case anyway.