
Add prim operators to query/update CUDA default RNG state

Open kiya00 opened this issue 1 year ago • 6 comments

Background: currently, the mask computed for dropout in the forward trace is passed to the backward trace, which consumes too much memory. We want Thunder to save only the seed and offset used for random number generation, so the mask can be recomputed in the backward pass.
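For illustration, here is a minimal PyTorch sketch of the recompute-instead-of-save idea, using a plain torch.Generator rather than Thunder's prims (dropout_fwd/dropout_bwd are hypothetical names, and the generator must live on the same device as the input):

    import torch

    def dropout_fwd(x, p, gen):
        # Snapshot the RNG state *before* drawing the mask; this small
        # state tensor is all that must survive until the backward pass.
        state = gen.get_state()
        mask = torch.rand(x.shape, generator=gen, device=x.device) >= p
        return x * mask / (1 - p), state

    def dropout_bwd(grad, p, gen, state):
        # Rewind the generator and redraw the exact same mask, instead of
        # keeping the full boolean mask alive between forward and backward.
        gen.set_state(state)
        mask = torch.rand(grad.shape, generator=gen, device=grad.device) >= p
        return grad * mask / (1 - p)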

Future steps:

  • [x] Add the prim operators to query/update the default CUDA RNG state (this PR)
  • [ ] Add a transformation that replaces uniform with uniform_philox (PR)
  • [ ] Make sure rematerialization works properly for dropout (PR)
  • [ ] Fix #231, so that the seed/offset passed to the backward trace are actual runtime values

This PR is the first step in fixing #114.

  • It adds the prim operators needed to query/update the default CUDA RNG state, following the same approach as nvFuser.
  • An example of the expected trace after replacing uniform with uniform_philox:
# def func(a):
#   b = ltorch.uniform_like(a, device=a.device, dtype=a.dtype)
#   return b * a

# Constructed by Delete Last Used (took 0 milliseconds)
import thunder.core.devices as devices
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(a):
  # a: "cuda:0 f32[2, 2]"
  t2 = get_rng_state_prim_impl(None, devices.Device("cuda:0"))  # t2: "cpu ui8[16]"
  (i3, i4) = unpack_rng_state_prim_impl(t2)
  del t2
  [t1] = nvFusion0(a, i3, i4)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i3, offset=i4)  # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.mul(t0, a)  # t1: "cuda:0 f32[2, 2]"
  t6 = update_rng_state_prim_impl(i3, i4)  # t6: "cpu ui8[16]"
  del i3, i4
  set_rng_state_prim_impl(t6, devices.Device("cuda:0"))
  del t6
  return {'output': t1, 'flat_args': [a], 'flat_output': (t1,)}, ((), (0, 0))

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  _, C1, = saved_for_backward
  clear_collection(saved_for_backward)
  del saved_for_backward
  t2, = cotangents
  clear_collection(cotangents)
  del cotangents
  i3, i4, = C1
  clear_collection(C1)
  del C1
  [t12] = nvFusion0(i3, i4, t2)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i3, offset=i4)  # t0: "cuda:0 f32[2, 2]"
    # t12 = prims.mul(t0, t2)  # t12: "cuda:0 f32[2, 2]"
  del i3, i4, t2
  return (t12,)
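Note that only the scalar seed and offset (i3, i4) are saved for the backward trace here, in place of the full mask; the backward nvFusion regenerates t0 from them via uniform_philox. The packed state itself travels as a 16-byte CPU uint8 tensor (t2 above). As a sketch of what unpacking it could look like, assuming a little-endian two-int64 (seed, offset) layout — the real unpack_rng_state_prim_impl defines the actual layout:

    import torch

    def unpack_rng_state(state: torch.Tensor):
        # Assumes the 16-byte CPU uint8 state packs (seed, offset) as two
        # little-endian int64s; the PR's prim impl may lay it out differently.
        seed, offset = state.view(torch.int64).tolist()
        return seed, offset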

cc @apaz-cli

kiya00 avatar Apr 24 '24 15:04 kiya00

I would prefer that we functionalize the RNG state handling within thunder and I wonder if this could be achieved with moderate effort (so the problem is similar to #145 but the solution likely much easier).
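For context, "functionalizing" the RNG state means each random op takes the state as an explicit input and returns the advanced state as an output, so ordering is enforced by dataflow rather than by a hidden global. A self-contained sketch with hypothetical names (not Thunder's API):

    import torch

    def uniform_functional(shape, state):
        # `state` is an explicit (seed, offset) pair instead of a hidden global.
        seed, offset = state
        gen = torch.Generator().manual_seed(seed)
        _ = torch.rand(offset, generator=gen)  # replay past draws; a Philox
                                               # backend would jump ahead in O(1)
        out = torch.rand(shape, generator=gen)
        return out, (seed, offset + out.numel())

    state = (42, 0)
    a, state = uniform_functional((2, 2), state)
    b, state = uniform_functional((2, 2), state)  # differs from a, reproducibly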

t-vi avatar Apr 26 '24 10:04 t-vi

> I would prefer that we functionalize the RNG state handling within thunder and I wonder if this could be achieved with moderate effort (so the problem is similar to #145 but the solution likely much easier).

Do you mean the state handling (this PR manually passes the state as dataflow to preserve the ordering of the RNG ops)? I didn't know about that; do we have a general way to handle this?

kiya00 avatar Apr 29 '24 08:04 kiya00

triage review —

There's some interesting stuff happening here! Before moving forward, can we do a design review with the broader team? We think lots of people would be interested.

@IvanYashchuk can help set up a design review.

mruberry avatar Apr 29 '24 19:04 mruberry

Sure, let me know your availability

kiya00 avatar Apr 30 '24 11:04 kiya00

Hi @t-vi @mruberry, I think it's ready to merge.

kiya00 avatar May 13 '24 09:05 kiya00

Mike asked for the design review meeting in https://github.com/Lightning-AI/lightning-thunder/pull/264#issuecomment-2083466316. We shouldn't merge this PR until more people are happy with the proposed features.

Here's the design doc that Yan wrote (NVIDIA-only link).

This PR is about adding special primitives that query and advance PyTorch's hidden RNG state, which would allow Thunder to generate code like that in test_uniform_philox_with_rng_state_prims:

    def func(shape, dtype, device):
        # Read the current RNG state for `device` and advance it; the
        # returned seed/offset feed the stateless Philox draw below.
        seed0, offset0 = prims.get_and_update_rng_state(None, None, device=device)
        out1 = ltorch.uniform_philox(shape, device=device, dtype=dtype, seed=seed0, offset=offset0)

        # Passing the previous seed/offset advances the state from that
        # point and ties the two draws together through their arguments.
        seed1, offset1 = prims.get_and_update_rng_state(seed0, offset0, device)
        out2 = ltorch.uniform_philox(shape, device=device, dtype=dtype, seed=seed1, offset=offset1)

        return out1, out2
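Threading seed0/offset0 into the second call turns the RNG ordering into an ordinary dataflow dependency, which is essentially the functionalized state handling discussed above: no transform can reorder the two draws without violating that dependency.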

This PR doesn't implement capturing torch.cuda.get_rng_state and torch.cuda.set_rng_state, so these operations are meant to be introduced into the trace by a Thunder trace transform.
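For reference, the eager-mode behavior those two functions provide (runnable on a machine with a CUDA device):

    import torch

    state = torch.cuda.get_rng_state()  # per-device state as a CPU uint8 tensor
    x = torch.rand(2, 2, device="cuda")
    torch.cuda.set_rng_state(state)     # rewind: the next draw repeats x
    y = torch.rand(2, 2, device="cuda")
    assert torch.equal(x, y)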

IvanYashchuk avatar May 13 '24 10:05 IvanYashchuk

Hi @mruberry @jjsjann123 @IvanYashchuk, I modified it according to our design review discussion; could you take a look?

kiya00 avatar May 28 '24 12:05 kiya00

Hi @IvanYashchuk @t-vi, I think it's ready to merge. Setting the NumberProxy to None works fine in this PR, but it causes the backward nvFusion's input to be None, which only affects the uniform transform PR; that PR needs the NumberProxy ready for its test case anyway.

kiya00 avatar May 29 '24 13:05 kiya00