
PyTorch Support

Open · kennykos opened this issue 7 months ago · 1 comment

When I try to pass a PyTorch tensor into a workunit as I would a CuPy array, as in the following script,

import torch
import pykokkos as pk

@pk.workunit
def work(wid, a):
    a[wid] = a[wid] + 1

def main():
    N = 10
    a = torch.ones(N)
    pk.set_default_space(pk.Cuda)
    pk.parallel_for("work", 10, work, a=a)
    print(a)

main()

I get the following error:

Traceback (most recent call last):
  File "/work/09661/gkk345/ls6/3dcapsules/python/development/gridding/tmp4.py", line 15, in <module>
    main()
  File "/work/09661/gkk345/ls6/3dcapsules/python/development/gridding/tmp4.py", line 12, in main
    pk.parallel_for("work", 10, work, a=a)
  File "/work/09661/gkk345/ls6/pykokkos/pykokkos/interface/parallel_dispatch.py", line 158, in parallel_for
    runtime_singleton.runtime.run_workunit(
  File "/work/09661/gkk345/ls6/pykokkos/pykokkos/core/runtime.py", line 153, in run_workunit
    return self.execute_workunit(name, policy, workunit, operation, parser, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/09661/gkk345/ls6/pykokkos/pykokkos/core/runtime.py", line 199, in execute_workunit
    members: PyKokkosMembers = self.precompile_workunit(workunit, execution_space, updated_decorator, updated_types, types_signature, restrict_views, restrict_signature, **kwargs)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/09661/gkk345/ls6/pykokkos/pykokkos/core/runtime.py", line 86, in precompile_workunit
    members: PyKokkosMembers = self.compiler.compile_object(module_setup,
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/09661/gkk345/ls6/pykokkos/pykokkos/core/compiler.py", line 178, in compile_object
    entity.AST = parser.fix_types(entity, updated_types)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/09661/gkk345/ls6/pykokkos/pykokkos/core/parsers/parser.py", line 170, in fix_types
    arg_obj.annotation = self.get_annotation_node(update_type)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/09661/gkk345/ls6/pykokkos/pykokkos/core/parsers/parser.py", line 283, in get_annotation_node
    raise ValueError(f"Type inference for {type} is not supported")
ValueError: Type inference for Tensor is not supported

For integration with PyTorch-based code, it would be very helpful if PyKokkos workunits accepted PyTorch tensors directly, the same way they accept CuPy arrays.

kennykos, Jul 02 '24 14:07