Fixes a bug when assigning a function to an intermediate variable
Category
- [ ] New feature
- [x] Bugfix
- [ ] Breaking change
- [ ] Refactoring
- [ ] Documentation
- [ ] Other (please explain)
Description
Fixes a bug when assigning a function to an intermediate variable. Warp's code generator currently doesn't handle this pattern correctly. When a `wp.static()` expression evaluates to a function, the code generator replaces the expression with a special AST node (`ast.Name`) whose identifier is `__warp_func__` and attaches the actual function object to it as the attribute `warp_func`.
In the repro below (adapted from the documentation), the code generator later encounters the name `func` in `output[tid] = func(a, b)` and tries to resolve it, but fails because `__warp_func__` is not in its symbol table, so code generation raises an error.
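To make the failure mode described above concrete, here is a minimal, self-contained sketch using only Python's standard `ast` module. It is not Warp's implementation: `static`, `StaticReplacer`, `transform`, `resolve`, and the `track_aliases` flag are hypothetical names, and the "symbol table" is a plain dict. It replaces a `static(...)` call that evaluates to a function with an `ast.Name` placeholder `__warp_func__` carrying a `warp_func` attribute, then shows that a direct call on the placeholder resolves, while a call through an intermediate variable fails unless the assignment keeps the attached function. `track_aliases=True` is just one possible way to propagate it, not necessarily what this PR does.

```python
import ast

PLACEHOLDER = "__warp_func__"


# Hypothetical stand-ins for @wp.func functions; plain Python, not Warp.
def do_add(a, b):
    return a + b


def do_sub(a, b):
    return a - b


class StaticReplacer(ast.NodeTransformer):
    """Evaluate each static(expr) call; if the result is callable, replace the
    call with a Name placeholder that carries the function as `warp_func`."""

    def __init__(self, namespace):
        self.namespace = namespace

    def visit_Call(self, node):
        self.generic_visit(node)  # replace nested static(...) calls first
        if isinstance(node.func, ast.Name) and node.func.id == "static":
            expr = ast.fix_missing_locations(ast.Expression(body=node.args[0]))
            value = eval(compile(expr, "<static>", "eval"), self.namespace)
            if callable(value):
                placeholder = ast.Name(id=PLACEHOLDER, ctx=ast.Load())
                placeholder.warp_func = value  # attach the real function object
                return ast.copy_location(placeholder, node)
        return node


def transform(src, namespace):
    return StaticReplacer(namespace).visit(ast.parse(src))


def resolve(tree, track_aliases):
    """Resolve the callee of every `x = f(...)` statement via a toy symbol table.

    track_aliases=False binds `func = <placeholder>` to the bare identifier
    "__warp_func__", which has no symbol-table entry (the reported bug).
    track_aliases=True binds the variable to the attached function instead."""
    symbols = {}
    callees = []
    for stmt in tree.body:
        if not isinstance(stmt, ast.Assign):
            continue
        target, value = stmt.targets[0].id, stmt.value
        if isinstance(value, ast.Name) and hasattr(value, "warp_func"):
            symbols[target] = value.warp_func if track_aliases else PLACEHOLDER
        elif isinstance(value, ast.Call) and isinstance(value.func, ast.Name):
            callee = value.func
            if hasattr(callee, "warp_func"):  # direct static(...)(a, b) call
                fn = callee.warp_func
            else:
                fn = symbols.get(callee.id)
                if isinstance(fn, str):  # bound to another name, e.g. the placeholder
                    fn = symbols.get(fn)
            if not callable(fn):
                raise NameError(f"cannot resolve function {callee.id!r}")
            callees.append(fn)
    return callees


ns = {"op_handlers": {"add": do_add, "sub": do_sub}, "op": "sub"}
direct_src = "out = static(op_handlers[op])(a, b)"
aliased_src = "func = static(op_handlers[op])\nout = func(a, b)"

print([f.__name__ for f in resolve(transform(direct_src, ns), track_aliases=False)])  # ['do_sub']
try:
    resolve(transform(aliased_src, ns), track_aliases=False)
except NameError as err:
    print("buggy:", err)  # buggy: cannot resolve function 'func'
print([f.__name__ for f in resolve(transform(aliased_src, ns), track_aliases=True)])  # ['do_sub']
```

The direct form works because the callee node itself carries `warp_func`; the aliased form only works once the assignment records the function object rather than the placeholder's name.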
NOTE: All tests pass, including the newly added test. However, please review carefully. Fingers crossed my tweaks don't unleash chaos.
```python
import warp as wp


@wp.func
def do_add(a: float, b: float):
    return a + b


@wp.func
def do_sub(a: float, b: float):
    return a - b


@wp.func
def do_mul(a: float, b: float):
    return a * b


op_handlers = {
    "add": do_add,
    "sub": do_sub,
    "mul": do_mul,
}

inputs = wp.array([[1, 2], [3, 0]], dtype=wp.float32)
outputs = wp.empty(2, dtype=wp.float32)

for op in op_handlers.keys():

    @wp.kernel
    def operate(input: wp.array(dtype=inputs.dtype, ndim=2), output: wp.array(dtype=wp.float32)):
        tid = wp.tid()
        a, b = input[tid, 0], input[tid, 1]
        # retrieve the handler for the captured `op` variable
        output[tid] = wp.static(op_handlers[op])(a, b)  # this works (as per the docs)
        # ERROR: the following should be equivalent, but fails during code generation
        # func = wp.static(op_handlers[op])
        # output[tid] = func(a, b)

    wp.launch(operate, dim=2, inputs=[inputs], outputs=[outputs])
    print(outputs.numpy())
```
Changelog
- Allow functions to be correctly assigned to variables in Warp kernels.
Before your PR is "Ready for review"
- [x] Do you agree to the terms under which contributions are accepted as described in Section 9 of the Warp License?
- [x] Have you read the Contributor Guidelines?
- [x] Have you written any new necessary tests?
- [ ] Have you added or updated any necessary documentation?
- [x] Have you added any files modified by compiling Warp and building the documentation to this PR (e.g. `stubs.py`, `functions.rst`)?
- [x] Does your code pass `ruff check` and `ruff format --check`?