[RFC] Adding CUDA Graphs conditional node support with user-friendly APIs
I’m working on CUDA Graph conditional node support in CuPy and am considering adding a more user-friendly graph-construction API.
I have two possible plans to achieve this goal:
- (A) “with” API
- (B) Functional API
(A) “with” API
pseudo code
```python
import cupy
from contextlib import contextmanager

# Add a conditional node to the capturing graph and
# switch to a different stream to capture the body graph
@contextmanager
def while_loop():
    ...

@contextmanager
def if_else():
    ...

a: cupy.ndarray = ...
b: cupy.ndarray = ...

with capture_context() as ctx:
    ...
    with ctx.while_loop(  # while loop
        cond_fn=lambda: cupy.all(a == b),
        cond_fn_args=(),  # Optional
    ):
        ...  # while loop body
    with ctx.if_else(cond_fn=lambda: cupy.all(a != b)) as cond:
        with cond.if_():
            ...
        with cond.else_():
            # Should we consider `else` body support?
            # Currently the conditional node does not support `else`,
            # but support seems to be coming in the near future.
            ...

graph = ctx.get_graph()  # get the captured graph
graph.launch()
```
Pros
- Considering the construction of CUDA graphs through stream capture, this approach provides a natural API.
- Allows for more concise writing of nested conditional nodes.
Cons
- Defining a loop using the “with” syntax in Python may not be intuitive.
- Adding support for an "else" body could result in awkward or unnatural syntax.
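To make the capture semantics of Plan A concrete, here is a minimal host-side toy sketch. The names (`CaptureContext`, the recorded node tuples) are illustrative only; a real implementation would drive CUDA stream capture rather than append to a Python list. The key point is that the `with` body executes exactly once during capture, and looping only happens when the launched graph re-executes the captured body.

```python
from contextlib import contextmanager

class CaptureContext:
    """Toy stand-in for the proposed capture context (hypothetical name)."""

    def __init__(self):
        self.nodes = []  # recorded nodes; a real impl captures a CUDA graph

    @contextmanager
    def while_loop(self, cond_fn, cond_fn_args=()):
        # During capture the body below runs exactly once; at launch time
        # the captured body would be re-executed while cond_fn is true.
        self.nodes.append(("while_cond", cond_fn, cond_fn_args))
        yield
        self.nodes.append(("while_end",))

    def get_graph(self):
        return list(self.nodes)

ctx = CaptureContext()
with ctx.while_loop(cond_fn=lambda: True):
    ctx.nodes.append(("op", "body_kernel"))  # stands in for a kernel launch

graph = ctx.get_graph()
```

After capture, `graph` holds one `while_cond` entry, the single recorded body, and a closing `while_end` marker, which mirrors the "walk through all nodes once, unconditionally" behavior of stream capture.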
(B) Functional API
Is this the same approach as `torch._higher_order_ops.while`?
pseudo code
```python
# Prepare a GraphConverter class to hold graph-capturing state
class GraphConverter:
    def __init__(self):
        self.state = ...  # need to have a global state

    # Function to define a while loop
    def while_loop(self, cond_fn):
        def inner_func(body_fn, body_args=None):
            ...  # operations to construct the graph
        return inner_func

    def condition(self, cond_fn):
        def inner_func(true_fn, false_fn=None):
            ...
        return inner_func

    def convert(self, func):
        ...
        return func

gc = GraphConverter()

@gc.convert  # A decorator to convert a function to a CUDA graph
def target(a: cupy.ndarray, b: cupy.ndarray):
    # While
    def while_fn(args):
        ...
    gc.while_loop(cond_fn=lambda: cupy.all(a == b))(while_fn, body_args)

    # You can also write it as follows
    @gc.while_loop(cond_fn=lambda: cupy.all(a == b))
    def while_fn2():
        nonlocal a, b
        ...

    # if-then-else
    def true_fn():
        ...
    def false_fn():
        ...
    gc.condition(cond_fn)(
        true_fn=true_fn,
        false_fn=false_fn,  # easier to add `else` support
    )
    # or
    @gc.condition(cond_fn)
    def true_fn2():
        ...

target(a, b)  # construct and execute the graph at the same time
```
Pros
- Simplifies the addition of "else" support when the conditional node supports an "else" body.
- Prevents issues related to unintentional variable overwriting and improper resource management, as functions naturally encapsulate variable scope.
Cons
- Writing deeply nested programs can become complex and cumbersome, although using decorators can help alleviate this issue to some extent.
- Requires the use of the "nonlocal" declaration or passing variables as arguments to access outer-scope variables.
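One way to see how Plan B's scoping trade-off plays out is an eager, host-side stand-in. `GraphBuilder` and its methods here are hypothetical names (not CuPy API); the real implementation would capture `body_fn` into a conditional WHILE node and `true_fn`/`false_fn` into IF branches instead of executing them in Python.

```python
class GraphBuilder:
    """Hypothetical eager (host-side) stand-in for Plan B's converter class."""

    def while_loop(self, cond_fn):
        def inner_func(body_fn, body_args=()):
            # Eager simulation: the condition is re-evaluated on the host.
            # A real implementation would record body_fn into a WHILE node.
            while cond_fn(*body_args):
                body_fn(*body_args)
        return inner_func

    def condition(self, cond_fn):
        def inner_func(true_fn, false_fn=None):
            # An IF node with an optional else branch.
            if cond_fn():
                true_fn()
            elif false_fn is not None:
                false_fn()
        return inner_func

gb = GraphBuilder()
state = {"x": 0, "branch": None}

# State must be threaded through explicitly (the "nonlocal or arguments" con).
gb.while_loop(lambda s: s["x"] < 10)(lambda s: s.update(x=s["x"] + 1), (state,))
gb.condition(lambda: state["x"] == 10)(
    true_fn=lambda: state.update(branch="true"),
    false_fn=lambda: state.update(branch="false"),
)
```

Note how `false_fn=None` makes the later addition of `else` support a purely additive change, which is the main advantage claimed above.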
Current status
I began implementing the API using the functional API approach (Plan B). However, switching to Plan A shouldn't be too time-consuming since the core implementation can be shared between both approaches.
I would like to ask for your thoughts on these two ideas. Please let me know your preferences, potential advantages, pitfalls, and any other insights you might have.
xref #6290
Thanks for putting thought into this. A few drive-by comments:
- I don't think we need to support `else` nodes, because without compiler help we have no way to validate whether the user-provided `else` conditional function is strictly identical to the negation of the `if` counterpart. Users who need an `else` node should just write another `if` node.
- Proposal A is fine with me if this is a concern:
  > Defining a loop using the “with” syntax in Python may not be intuitive.

  because in order to use stream capture to build a graph that contains conditional nodes, we have to walk through all the nodes at least once, unconditionally. The `with` syntax allows that.
- Proposal A's `ctx` usage to get a graph seems a bit odd to me. Usually after leaving the context manager, the context object is invalidated. (xref: https://github.com/cupy/cupy/pull/8502#issuecomment-2311671823) But it is probably just a nitpick.
- I don't like `gc` in Proposal B, as it could be confused with the garbage collector. How about calling it, say, `GraphBuilder` (and aliasing it to `gb`)?
- When using the decorator approach in Proposal B, such as `@gc.while_loop` and `@gc.condition`, it seems we still need to call `while_fn2` and `true_fn2` in `target` in order to have them executed at least once and captured. Is this an oversight?
In any approach, my biggest question is how you plan to generate the additional kernels that take a `cudaGraphConditionalHandle` as an argument and call the device function `cudaGraphSetConditional` to set the value based on the user-provided Python conditional function in the kernel. It seems to me that either CuPy JIT or CuPy kernel fusion is needed?
I see the primary benefit of API (B) is that we can "simulate" the graph conditionals on host-side. This should ease migrating existing CuPy code to CUDA Graph.
```python
def target_func():
    x = ...

    def loop_body(x):
        x += 1

    while_loop(lambda x: x < 10, loop_body, args=(x,))

# Run everything eagerly (by evaluating the result of the loop conditional
# function (`x < 10`) on the host) - for debugging purposes only
capture_and_launch(target_func, simulate=True)

# Run using a real CUDA Graph
capture_and_launch(target_func)
```
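A runnable sketch of the `simulate=True` path only, under stated assumptions: `while_loop`, `capture_and_launch`, and the `_SIMULATE` flag are all hypothetical names, and the non-simulated path (real stream capture and graph launch) is deliberately left unimplemented here.

```python
_SIMULATE = False  # module-level flag toggled by capture_and_launch

def while_loop(cond_fn, body_fn, args=()):
    if _SIMULATE:
        # Host-side emulation: evaluate the condition eagerly in Python.
        while cond_fn(*args):
            body_fn(*args)
    else:
        # The real path would record a conditional WHILE node via stream
        # capture; it is not sketched here.
        raise NotImplementedError("CUDA Graph capture path not sketched")

def capture_and_launch(func, simulate=False):
    global _SIMULATE
    _SIMULATE = simulate
    try:
        func()
    finally:
        _SIMULATE = False

counter = {"x": 0}  # mutable state so the loop body has a visible effect

def target_func():
    while_loop(lambda c: c["x"] < 10,
               lambda c: c.update(x=c["x"] + 1),
               args=(counter,))

capture_and_launch(target_func, simulate=True)
```

Because the same `target_func` would be passed to both modes, existing code could be debugged eagerly on the host before switching to a real CUDA Graph launch.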
> my biggest question is how you plan to generate additional kernels that take a `cudaGraphConditionalHandle` as argument and call the device function `cudaGraphSetConditional` to set the value based on the user-provided Python conditional function in the kernel.
The initial implementation will require an additional kernel invocation that only calls `cudaGraphSetConditional(x)`, where `x` is a scalar (on the GPU) returned by the conditional function.
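Such an extra kernel could look like the CUDA source below. This is an illustrative sketch, not CuPy's actual implementation: device-side `cudaGraphSetConditional` requires CUDA 12.3+, and compiling or launching the kernel requires a GPU, so only the source string is shown and checked here.

```python
SET_CONDITIONAL_SRC = r'''
extern "C" __global__
void set_conditional(cudaGraphConditionalHandle handle,
                     const unsigned int* pred) {
    // Copy the scalar predicate (computed on the GPU by the user's
    // conditional function) into the conditional node's handle.
    cudaGraphSetConditional(handle, *pred);
}
'''

# On a CUDA 12.3+ system this might be compiled with, e.g.,
# cupy.RawKernel(SET_CONDITIONAL_SRC, "set_conditional") and launched
# with a single thread inside the captured region.
```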