taichi
taichi copied to clipboard
Automatic Differentiation with Matrix Calculation
My task is to test automatic differentiation (autodiff) with matrix operations.
Here is my code: `import taichi as ti import math ti.init(arch=ti.gpu,debug=True)
N = 10 Dim3=3 Dim2=2 Dim1=1 positions = ti.Vector.field(Dim2, dtype=ti.f32, shape=(N,),needs_grad=True)# define a 2D vector field pos2Ds=ti.Vector.field(Dim1, dtype=ti.f32, shape=(N,),needs_grad=True) realpos=ti.field(dtype=ti.f32, shape=(N,),needs_grad=True) L=ti.field(dtype=ti.f32, shape=(), needs_grad=True)
theta = math.pi / 4 # rotate 45 degree rotation_matrix = ti.Matrix([[ti.cos(theta), -ti.sin(theta)], [ti.sin(theta), ti.cos(theta)]])# Define a rotation matrix
@ti.kernel def init(): for i in ti.grouped(positions): positions[i] = [1.0,1.0] @ti.kernel def transform(): for i in positions: pos2Ds[i] = (rotation_matrix @ positions[i]).y @ti.kernel def comp_loss(): for i in ti.grouped(pos2Ds): L[None]+=(realpos[i]-pos2Ds[i].x)
init()
with ti.ad.Tape(loss=L,validation=True):
transform()
comp_loss(
)
print(L[None])
print(positions.grad)`
But when I run it, I get the following error:
RuntimeError Traceback (most recent call last)
Cell In[1], line 36
33 transform()
34 # Kernel invocations in this scope will later contribute to partial derivatives of
35 # U with respect to input variables such as x.
---> 36 comp_loss(
37 ) # The tape will automatically compute dU/dx and save the results in x.grad
38 print(L[None])
39 print(positions.grad)
File c:\Users\LID0E\AppData\Local\miniconda3\envs\Gaussin-Taichi\lib\site-packages\taichi\lang\kernel_impl.py:1103, in _kernel_impl.
File c:\Users\LID0E\AppData\Local\miniconda3\envs\Gaussin-Taichi\lib\site-packages\taichi\lang\shell.py:27, in _shell_pop_print.
File c:\Users\LID0E\AppData\Local\miniconda3\envs\Gaussin-Taichi\lib\site-packages\taichi\lang\kernel_impl.py:1035, in Kernel.call(self, *args, **kwargs) 1033 key = self.ensure_compiled(*args) 1034 kernel_cpp = self.compiled_kernels[key] -> 1035 return self.launch_kernel(kernel_cpp, *args)
File c:\Users\LID0E\AppData\Local\miniconda3\envs\Gaussin-Taichi\lib\site-packages\taichi\lang\kernel_impl.py:966, in Kernel.launch_kernel(self, t_kernel, *args) 964 if impl.get_runtime().print_full_traceback: 965 raise e --> 966 raise e from None 968 ret = None 969 ret_dt = self.return_type
File c:\Users\LID0E\AppData\Local\miniconda3\envs\Gaussin-Taichi\lib\site-packages\taichi\lang\kernel_impl.py:959, in Kernel.launch_kernel(self, t_kernel, *args) 957 prog = impl.get_runtime().prog 958 # Compile kernel (& Online Cache & Offline Cache) --> 959 compiled_kernel_data = prog.compile_kernel(prog.config(), prog.get_device_caps(), t_kernel) 960 # Launch kernel 961 prog.launch_kernel(compiled_kernel_data, launch_ctx)
RuntimeError: [taichi/ir/ir.h:taichi::lang::IRNode::as@248] Assertion failure: is<T>()
How can I fix this problem?