Add support for conversion of torch scalar to taichi scalar
Torch scalars are 0-dimensional tensors.
Currently, Taichi cannot convert a zero-dimensional tensor (e.g. torch.tensor(2, dtype=torch.int32)) to a Taichi scalar (e.g. ti.int32).
Example
Suppose you have this fill kernel:
import torch
import taichi as ti

ti.init(arch=ti.cpu)

@ti.kernel
def fill(out: ti.types.ndarray(dtype=ti.int32), value: ti.int32):
    for I in ti.grouped(out):
        out[I] = value

We can pass a Python scalar:
out = torch.empty((10,), dtype=torch.int32)
value = 2  # Python scalar
fill(out, value)
print(out)  # tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=torch.int32)
We can also pass a NumPy scalar:
import numpy as np

out = np.empty((10,), dtype=np.int32)
value = np.int32(2)  # NumPy scalar
fill(out, value)
print(repr(out))  # array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int32)
But we cannot pass a torch scalar:
out = torch.empty((10,), dtype=torch.int32)
value = torch.tensor(2, dtype=torch.int32) # torch scalar
fill(out, value)
TaichiRuntimeTypeError:
Argument 1 (type=<class 'torch.Tensor'>) cannot be converted into required type i32
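Until Taichi supports this conversion, one possible workaround (a sketch under stated assumptions, not part of Taichi's API) is to unwrap the 0-dimensional tensor into a plain Python number with torch.Tensor.item() before calling the kernel. The names to_python_scalar and unwrap_scalar_args below are hypothetical; they rely only on the ndim attribute and item() method, which torch.Tensor, numpy.ndarray and NumPy scalars all provide.

```python
from typing import Any, Callable


def to_python_scalar(value: Any) -> Any:
    """Hypothetical helper: unwrap a 0-d array-like into a Python number.

    Anything that is not 0-dimensional (e.g. the `out` tensor, or a plain
    Python int) is returned unchanged.
    """
    # Duck typing: torch.Tensor and numpy.ndarray both expose `ndim` and `item()`.
    if getattr(value, "ndim", None) == 0 and callable(getattr(value, "item", None)):
        return value.item()
    return value


def unwrap_scalar_args(kernel: Callable) -> Callable:
    """Hypothetical wrapper: convert every 0-d argument before calling `kernel`."""
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        args = tuple(to_python_scalar(a) for a in args)
        kwargs = {k: to_python_scalar(v) for k, v in kwargs.items()}
        return kernel(*args, **kwargs)
    return wrapper
```

With the fill kernel above, fill(out, to_python_scalar(value)) or unwrap_scalar_args(fill)(out, value) should then behave like the Python-scalar case. Note that item() copies the value back to the host, so for a CUDA tensor it forces a device synchronization on every call.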
As you mentioned, a torch scalar is a zero-dimensional tensor, so you still need to treat it as an ndarray in Taichi:
import torch
import taichi as ti

ti.init(arch=ti.cpu)

@ti.kernel
def fill(out: ti.types.ndarray(), value: ti.types.ndarray()):
    for I in ti.grouped(out):
        out[I] = value[None]

out = torch.empty((10,), dtype=torch.int32)
value = torch.tensor(2, dtype=torch.int32)  # torch scalar
fill(out, value)
Try this out.