
[RFC] Adding CUDA Graphs conditional node support with user-friendly APIs

Open · so298 opened this issue 1 year ago · 3 comments

I’m working on CuPy’s CUDA Graph conditional node support and am considering adding a more user-friendly graph construction API to CuPy.
I have two possible plans to achieve this goal:

  • (A) “with” API
  • (B) Functional API

(A) “with” API

Pseudocode:

import cupy
from contextlib import contextmanager

# Add a conditional node to the graph being captured, then
# switch to a different stream to capture the body graph
@contextmanager
def while_loop():
    ...
@contextmanager
def if_else():
    ...

a: cupy.ndarray = ...
b: cupy.ndarray = ...

with capture_context() as ctx:
    ...
    with ctx.while_loop( # while loop
        cond_fn=lambda: cupy.all(a == b),
        cond_fn_args=() # Optional
    ):
        ... # while loop body

    with ctx.if_else(cond_fn=lambda: cupy.all(a != b)) as cond:
        with cond.if_():
            ...
        with cond.else_():
            # Should we support an `else` body?
            # Conditional nodes currently do not support `else`,
            # but support appears likely to be added in the near future.
            ...

graph = ctx.get_graph() # get captured graph
graph.launch()
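To make the nesting behavior of Plan A concrete, here is a host-only mock that assumes nothing about CuPy's real capture machinery. `CaptureContext`, `record`, `while_loop`, and `if_` are hypothetical names (the if/else pair is simplified to a single `if_` block), and appending to a Python list stands in for capturing a kernel on a stream. Each `with` block adds a conditional node to the enclosing graph, pushes a fresh body list (the "switch to a different stream" step in the comments above), and pops it on exit, so nested conditionals fall out of the stack naturally. Note that bodies are recorded exactly once regardless of what the condition would evaluate to.

```python
from contextlib import contextmanager


class CaptureContext:
    """Hypothetical host-only mock of the proposed "with" API (no CUDA calls)."""

    def __init__(self):
        self.root = []             # top-level graph: a list of recorded nodes
        self._stack = [self.root]  # innermost graph currently being captured

    def record(self, op):
        # Stand-in for a kernel launch captured on the current stream
        self._stack[-1].append(op)

    @contextmanager
    def _conditional(self, kind, cond_fn):
        body = []
        # Add the conditional node to the enclosing graph...
        self._stack[-1].append((kind, cond_fn, body))
        # ...then redirect capture into its body graph
        self._stack.append(body)
        try:
            yield
        finally:
            self._stack.pop()  # resume capturing into the parent graph

    def while_loop(self, cond_fn):
        return self._conditional("while", cond_fn)

    def if_(self, cond_fn):
        return self._conditional("if", cond_fn)


ctx = CaptureContext()
with ctx.while_loop(cond_fn=lambda: True):
    ctx.record("loop_kernel")
    with ctx.if_(cond_fn=lambda: False):
        ctx.record("nested_kernel")  # recorded even though cond_fn is False
ctx.record("after_loop")
```

The stack is what keeps deeply nested bodies concise in this style: the nesting of the `with` blocks is the nesting of the graph.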

Pros

  • Since CUDA graphs are constructed here through stream capture, this approach provides a natural API.
  • Allows nested conditional nodes to be written more concisely.

Cons

  • Defining a loop using the “with” syntax in Python may not be intuitive.
  • Adding support for an "else" body can result in awkward or unnatural syntax.

(B) Functional API

Is this the same approach as torch._higher_order_ops.while?

Pseudocode:

# Prepare GraphConverter class to hold graph capturing state
class GraphConverter:
    def __init__(self):
        self.state = ... # need to have a global state

    # Function to define while loop
    def while_loop(self, cond_fn):
        def inner_func(body_fn, body_args=None):
            ... # operation to construct graph
        return inner_func

    def condition(self, cond_fn):
        def inner_func(true_fn, false_fn=None):
            ...
        return inner_func

    def convert(self, func):
        ...
        return func

gc = GraphConverter()
@gc.convert # A decorator to convert function to CUDA graph
def target(a: cupy.ndarray, b: cupy.ndarray):
    # While
    def while_fn(a, b):
        ...
    gc.while_loop(cond_fn=lambda: cupy.all(a == b))(while_fn, body_args=(a, b))

    # You can also write as follows
    @gc.while_loop(cond_fn=lambda: cupy.all(a == b))
    def while_fn2():
        nonlocal a, b
        ...

    # if-then-else
    def true_fn():
        ...
    def false_fn():
        ...
    gc.condition(cond_fn)(
        true_fn=true_fn,
        false_fn=false_fn # easier to add `else` support
    )
    # or
    @gc.condition(cond_fn)
    def true_fn2():
        ...

target(a, b) # construct and execute graph at the same time
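For comparison, here is a host-only mock of Plan B under the same assumptions: list appends stand in for captured kernels, and `GraphRecorder` and its methods are hypothetical, loosely following the `GraphConverter` sketch above. Each `inner_func` runs the body function once while a fresh body list is active, which is one way a functional API could capture each body exactly once. An optional `false_fn` is traced the same way, which illustrates why an `else` body slots in easily with this style.

```python
class GraphRecorder:
    """Hypothetical host-only stand-in for the proposed GraphConverter."""

    def __init__(self):
        self.nodes = []              # top-level graph
        self._stack = [self.nodes]   # graph currently being captured

    def record(self, op):
        # Stand-in for a kernel launch captured into the current graph
        self._stack[-1].append(op)

    def _trace(self, fn, args):
        # Run fn once with a fresh body active and return what it recorded
        body = []
        self._stack.append(body)
        try:
            fn(*args)
        finally:
            self._stack.pop()
        return body

    def while_loop(self, cond_fn):
        def inner_func(body_fn, body_args=None):
            body = self._trace(body_fn, body_args or ())
            self._stack[-1].append(("while", cond_fn, body))
        return inner_func

    def condition(self, cond_fn):
        def inner_func(true_fn, false_fn=None):
            then_body = self._trace(true_fn, ())
            else_body = self._trace(false_fn, ()) if false_fn is not None else None
            self._stack[-1].append(("if", cond_fn, then_body, else_body))
        return inner_func


g = GraphRecorder()


def body(x):
    g.record(f"step({x})")


g.while_loop(lambda: True)(body, (1,))
g.condition(lambda: False)(
    true_fn=lambda: g.record("then"),
    false_fn=lambda: g.record("else"),
)
```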

Pros

  • Simplifies the addition of "else" support when the conditional node supports an "else" body.
  • Prevents issues related to unintentional variable overwriting and improper resource management, as functions naturally encapsulate variable scope.

Cons

  • Writing deeply nested programs can become complex and cumbersome, although using decorators can help alleviate this issue to some extent.
  • Requires the use of the "nonlocal" declaration or passing variables as arguments to access outer-scope variables.
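The second con is worth spelling out, because it applies to CuPy arrays too: an augmented assignment such as `a += 1` inside a body function makes `a` local to that function, so without `nonlocal` it raises UnboundLocalError, whereas an in-place update through indexing (for example `a[...] += 1`) does not rebind the name. The plain-Python sketch below (no CuPy needed; the function names are illustrative only) shows the three cases.

```python
def target_with_rebinding():
    count = 0

    def body():
        count += 1  # rebinds an outer name without `nonlocal`

    try:
        body()
    except UnboundLocalError:
        return "UnboundLocalError"
    return count


def target_with_nonlocal():
    count = 0

    def body():
        nonlocal count  # explicitly rebind the enclosing variable
        count += 1

    body()
    return count


def target_with_mutation():
    buf = [0]

    def body():
        buf[0] += 1  # item assignment mutates in place; no rebinding needed

    body()
    return buf[0]
```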

Current status

I began implementing the API using the functional API approach (Plan B). However, switching to Plan A shouldn't be too time-consuming since the core implementation can be shared between both approaches.

I would like to hear your thoughts on these two ideas. Please let us know your preferences, potential advantages and pitfalls, and any other insights you might have.

so298, Aug 29 '24 05:08

xref #6290

kmaehashi, Sep 07 '24 09:09

Thanks for putting thoughts in this. A few drive-by comments:

  1. I don't think we need to support else nodes, because without compiler help we have no way to validate that the user-provided else conditional function is strictly the negation of its if counterpart. Users who need an else node should just write another if node.
  2. Proposal A is fine with me despite this concern:

       Defining a loop using the “with” syntax in Python may not be intuitive.

     In order to use stream capture to build a graph that contains conditional nodes, we have to walk through all the nodes at least once, unconditionally. The "with" syntax allows that.

  3. Proposal A's use of ctx to get the graph seems a bit odd to me. Usually, the context object is invalidated after leaving the context manager. (xref: https://github.com/cupy/cupy/pull/8502#issuecomment-2311671823) But this is probably just a nitpick.
  4. I don't like the name gc in Proposal B, as it could be confused with the garbage collector. How about calling it, say, "GraphBuilder" (and aliasing it to gb)?
  5. When using the decorator approach in Proposal B, such as @gc.while_loop and @gc.condition, it seems we still need to call while_fn2 and true_fn2 in target so that they are executed at least once and captured. Is this an oversight?
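On the decorator question (whether while_fn2 and true_fn2 must still be called): that depends on what `inner_func` does. A decorator is applied when the `def` statement executes, so if `inner_func` traces `body_fn` right away, the body is captured once at definition time inside `target`, with no further call. The sketch below assumes exactly that behavior (all names are hypothetical). Note that the decorated name is rebound to whatever `inner_func` returns; in the RFC sketch, where it returns nothing, `while_fn2` would become None afterwards.

```python
traced = []  # names of bodies captured, in order


def while_loop(cond_fn):
    """Hypothetical: returns a decorator that traces the body immediately."""
    def inner_func(body_fn, body_args=None):
        traced.append(body_fn.__name__)  # stand-in for "capture this body once"
        body_fn(*(body_args or ()))      # run the body once during capture
        return body_fn                   # keep the decorated name usable
    return inner_func


def target():
    @while_loop(cond_fn=lambda: True)
    def while_fn2():
        pass
    # No explicit while_fn2() call: decorating it already ran the body once.
    return while_fn2


fn = target()
```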

In either approach, my biggest question is how you plan to generate the additional kernels that take a cudaGraphConditionalHandle as an argument and call the device function cudaGraphSetConditional in the kernel to set the value based on the user-provided Python conditional function. It seems to me that either CuPy JIT or CuPy kernel fusion would be needed?

leofang, Sep 09 '24 13:09

The primary benefit I see in API (B) is that we can "simulate" the graph conditionals on the host side. This should ease migrating existing CuPy code to CUDA Graphs.

def target_func():
  x = ...
  def loop_body(x):
    x += 1
  while_loop(lambda x: x < 10, loop_body, args=(x,))

# Run everything eagerly (by evaluating the result of loop conditional function (`x < 10`) on host) - for debugging purposes only
capture_and_launch(target_func, simulate=True)

# Run using a real CUDA Graph
capture_and_launch(target_func)
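A host-side simulate mode could be as simple as replacing the conditional node with a real Python loop. The sketch below assumes a hypothetical `while_loop(cond_fn, body_fn, args)` with the signature used above; a one-element list stands in for the 0-d GPU array so that the body can mutate it in place, as `x += 1` does for a CuPy array. With CuPy, each condition check would call `bool()` on a device scalar, which implies a device-to-host copy every iteration: acceptable for debugging, but exactly what the real graph avoids.

```python
def while_loop(cond_fn, body_fn, args=()):
    """Hypothetical eager stand-in for a conditional while node.

    Checks the condition before every iteration, runs the body, and repeats:
    ordinary while-loop semantics, evaluated on the host.
    """
    while cond_fn(*args):
        body_fn(*args)


def target_func(log):
    x = [0]  # stands in for a 0-d GPU array; the body mutates it in place

    def loop_body(x):
        x[0] += 1
        log.append(x[0])

    while_loop(lambda x: x[0] < 10, loop_body, args=(x,))
    return x[0]


log = []
result = target_func(log)
```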

    my biggest question is how you plan to generate additional kernels that take an cudaGraphConditionalHandle as argument and call the device function cudaGraphSetConditional to set the value based on the user-provided Python conditional function in the kernel.

The initial implementation will require an additional kernel invocation that only calls cudaGraphSetConditional(x), where x is a scalar (on the GPU) returned by the conditional function.
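Such a kernel might look like the following sketch, assuming CUDA 12.3 or later (where conditional nodes and the device function cudaGraphSetConditional(handle, value) were introduced) and a conditional function that returns a 0-d boolean array. `SET_CONDITIONAL_SRC` and `build_set_conditional_kernel` are hypothetical names, and whether NVRTC sees the cudaGraphConditionalHandle declaration without extra headers is an assumption; cupy.RawKernel also accepts backend='nvcc' if it does not. The kernel would be launched with a single thread before the conditional node (and, for a while loop, at the end of the body).

```python
# CUDA C++ source for the extra "set conditional" kernel (a sketch).
SET_CONDITIONAL_SRC = r"""
extern "C" __global__
void set_conditional(cudaGraphConditionalHandle handle, const bool* flag)
{
    // Launched with one thread: copy the GPU-side scalar into the handle
    cudaGraphSetConditional(handle, *flag ? 1u : 0u);
}
"""


def build_set_conditional_kernel():
    """Hypothetical helper; imports CuPy lazily so this module loads without a GPU."""
    import cupy  # deferred import: only needed when actually building the kernel

    return cupy.RawKernel(SET_CONDITIONAL_SRC, "set_conditional")
```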

kmaehashi, Sep 16 '24 08:08