pykokkos
Issue: PyTorch tensor support in workunits
When I pass a PyTorch tensor into a workunit the same way I would pass a CuPy array, as in the following script:
import torch
import pykokkos as pk
# Workunit executed once per index of the parallel_for range:
#   wid -- index of the current work item
#   a   -- array argument updated in place (element wid is incremented)
# Plain `#` comments are used instead of a docstring because PyKokkos
# translates the workunit body itself.
@pk.workunit
def work(wid, a):
    a[wid] = a[wid] + 1
def main():
    # Reproducer: pass a PyTorch tensor to a PyKokkos workunit the same way
    # a CuPy array would be passed, then print the updated tensor.
    N = 10
    # NOTE(review): torch.ones(N) allocates the tensor on the CPU, while the
    # default execution space set below is Cuda (CuPy arrays, by contrast,
    # are GPU-resident). For a like-for-like comparison consider
    # torch.ones(N, device="cuda") -- confirm how PyKokkos should treat
    # host tensors once Tensor arguments are supported.
    a = torch.ones(N)
    pk.set_default_space(pk.Cuda)
    # Iterate over exactly N indices so the range always matches the tensor
    # length (a hard-coded 10 would write out of bounds if N shrank).
    pk.parallel_for("work", N, work, a=a)
    print(a)


if __name__ == "__main__":
    main()
I get the following error:
Traceback (most recent call last):
File "/work/09661/gkk345/ls6/3dcapsules/python/development/gridding/tmp4.py", line 15, in <module>
main()
File "/work/09661/gkk345/ls6/3dcapsules/python/development/gridding/tmp4.py", line 12, in main
pk.parallel_for("work", 10, work, a=a)
File "/work/09661/gkk345/ls6/pykokkos/pykokkos/interface/parallel_dispatch.py", line 158, in parallel_for
runtime_singleton.runtime.run_workunit(
File "/work/09661/gkk345/ls6/pykokkos/pykokkos/core/runtime.py", line 153, in run_workunit
return self.execute_workunit(name, policy, workunit, operation, parser, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/work/09661/gkk345/ls6/pykokkos/pykokkos/core/runtime.py", line 199, in execute_workunit
members: PyKokkosMembers = self.precompile_workunit(workunit, execution_space, updated_decorator, updated_types, types_signature, restrict_views, restrict_signature, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/work/09661/gkk345/ls6/pykokkos/pykokkos/core/runtime.py", line 86, in precompile_workunit
members: PyKokkosMembers = self.compiler.compile_object(module_setup,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/work/09661/gkk345/ls6/pykokkos/pykokkos/core/compiler.py", line 178, in compile_object
entity.AST = parser.fix_types(entity, updated_types)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/work/09661/gkk345/ls6/pykokkos/pykokkos/core/parsers/parser.py", line 170, in fix_types
arg_obj.annotation = self.get_annotation_node(update_type)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/work/09661/gkk345/ls6/pykokkos/pykokkos/core/parsers/parser.py", line 283, in get_annotation_node
raise ValueError(f"Type inference for {type} is not supported")
ValueError: Type inference for Tensor is not supported
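Currently, type inference in `get_annotation_node` rejects the `Tensor` type outright. Supporting PyTorch tensors as workunit arguments, as CuPy arrays already are, would make it much easier to integrate PyKokkos into PyTorch-based code.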