
Fixes a bug when assigning a function to an intermediate variable

Open mehdiataei opened this issue 1 year ago • 0 comments

Category

  • [ ] New feature
  • [x] Bugfix
  • [ ] Breaking change
  • [ ] Refactoring
  • [ ] Documentation
  • [ ] Other (please explain)

Description

Fixes a bug when assigning a function to an intermediate variable. Warp's code generator currently doesn't handle this pattern correctly. When an expression returns a function, the code generator replaces that expression with a special AST node (ast.Name) whose identifier is __warp_func__, and attaches the actual function object to the node as the attribute warp_func.

In the repro below (adapted from the docs, with the failing lines commented out), the code generator later encounters the name func in output[tid] = func(a, b) and tries to resolve it. The lookup fails because __warp_func__ is not in the symbol table, which leads to the error.
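To make the failure mode concrete, here is a toy sketch in plain Python. It is not Warp's actual code generator: the helpers make_func_node, resolve_naive, and resolve_with_attached_func, and the dict-based symbol table, are hypothetical names invented for illustration. Only the ast.Name node, the __warp_func__ identifier, and the warp_func attribute come from the description above.

```python
import ast


def do_add(a, b):
    return a + b


def make_func_node(func):
    # Mimics the replacement described above: a Name node with a
    # placeholder identifier and the function attached as an attribute.
    node = ast.Name(id="__warp_func__", ctx=ast.Load())
    node.warp_func = func
    return node


def resolve_naive(node, symbols):
    # Hypothetical lookup that only consults the symbol table.
    return symbols[node.id]


def resolve_with_attached_func(node, symbols):
    # Hypothetical fix: prefer the function attached to the node,
    # and fall back to the symbol table otherwise.
    attached = getattr(node, "warp_func", None)
    if attached is not None:
        return attached
    return symbols[node.id]


symbols = {}  # toy symbol table
rhs = make_func_node(do_add)  # right-hand side of `func = <function expr>`

try:
    resolve_naive(rhs, symbols)
    naive_failed = False
except KeyError:
    # "__warp_func__" was never registered as a symbol
    naive_failed = True

# With the attached function honored, the assignment can bind `func`.
symbols["func"] = resolve_with_attached_func(rhs, symbols)
result = symbols["func"](1.0, 2.0)
```

In this sketch the naive lookup fails on the placeholder name, while the second resolver lets func be bound and called. Whether the real fix takes this shape depends on the diff in this PR.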

NOTE: All tests pass, including the newly added test. However, please review carefully. Fingers crossed my tweaks don't unleash chaos.

import warp as wp

@wp.func
def do_add(a: float, b: float):
    return a + b

@wp.func
def do_sub(a: float, b: float):
    return a - b

@wp.func
def do_mul(a: float, b: float):
    return a * b

op_handlers = {
    "add": do_add,
    "sub": do_sub,
    "mul": do_mul,
}

inputs = wp.array([[1, 2], [3, 0]], dtype=wp.float32)
outputs = wp.empty(2, dtype=wp.float32)

for op in op_handlers.keys():

    @wp.kernel
    def operate(input: wp.array(dtype=inputs.dtype, ndim=2), output: wp.array(dtype=wp.float32)):
        tid = wp.tid()
        a, b = input[tid, 0], input[tid, 1]
        # retrieve the right function to use for the captured `op` variable
        output[tid] = wp.static(op_handlers[op])(a, b)  # this works (as per the docs)
        # ERROR: the form below unexpectedly fails, even though it should be equivalent
        # func = wp.static(op_handlers[op])
        # output[tid] = func(a, b)  # this does not work

    wp.launch(operate, dim=2, inputs=[inputs], outputs=[outputs])

print(outputs.numpy())

Changelog

  • Allow functions to be correctly assigned to variables in Warp kernels.

Before your PR is "Ready for review"

  • [x] Do you agree to the terms under which contributions are accepted as described in Section 9 of the Warp License?
  • [x] Have you read the Contributor Guidelines?
  • [x] Have you written any new necessary tests?
  • [ ] Have you added or updated any necessary documentation?
  • [x] Have you added any files modified by compiling Warp and building the documentation to this PR (e.g., stubs.py, functions.rst)?
  • [x] Does your code pass ruff check and ruff format --check?

mehdiataei, Oct 10 '24 22:10