tutel/jit_kernels/sparse.py, torch.float16: the CUDA result of the backward-gate kernel is inconsistent with the CPU result, and an array is accessed out of bounds
Code:
```python
import numpy as np
import torch
from tutel.jit_kernels import sparse as jit_kernel

print(torch.__version__)


def moe_dispatch_bwd_gate():
    samples = 2
    capacity = 2
    hidden = 2
    num_experts = 1
    indices = [0, 0]
    locations = [0, 0]
    input = [0.4946, -0.0043, 0.5386, -0.8354]
    dispatch = [0.7085, 0.8257, -0.1455, -0.1788]

    # int32
    indices_t = np.asarray(indices, dtype=np.int32)
    locations_t = np.asarray(locations, dtype=np.int32)
    # float / half
    input_t = np.asarray(input, dtype=np.float16)
    dispatch_t = np.asarray(dispatch, dtype=np.float16)

    indices_gpu = torch.from_numpy(indices_t).cuda()
    locations_gpu = torch.from_numpy(locations_t).cuda()
    input_gpu = torch.from_numpy(input_t).cuda()
    dispatch_gpu = torch.from_numpy(dispatch_t).cuda()
    print("cuda:")
    print("indices_gpu:", indices_gpu)
    print("locations_gpu:", locations_gpu)
    print("input_gpu:", input_gpu)
    print("dispatch_gpu:", dispatch_gpu)

    # call gpu func
    grad_gates = torch.zeros([samples], dtype=input_gpu.dtype, device=input_gpu.device)
    moe_dispatch_bwd_gate = jit_kernel.create_backward_gate(input_gpu.dtype, input_gpu.is_cuda)
    moe_dispatch_bwd_gate(grad_gates, indices_gpu, locations_gpu, input_gpu, dispatch_gpu, extra=[samples, hidden, capacity])
    print("grad_gates:", grad_gates)

    # call cpu func
    input_t = np.asarray(input, dtype=np.float32)
    dispatch_t = np.asarray(dispatch, dtype=np.float32)
    indices_cpu = torch.from_numpy(indices_t)
    locations_cpu = torch.from_numpy(locations_t)
    input_cpu = torch.from_numpy(input_t)
    print("cpu:")
    # print("input_cpu:", input_cpu)
    dispatch_cpu = torch.from_numpy(dispatch_t)
    grad_gates_cpu = torch.zeros([samples], dtype=input_cpu.dtype, device=input_cpu.device)
    moe_dispatch_bwd_gate = jit_kernel.create_backward_gate(input_cpu.dtype, input_cpu.is_cuda)
    moe_dispatch_bwd_gate(grad_gates_cpu, indices_cpu, locations_cpu, input_cpu, dispatch_cpu, extra=[samples, hidden, capacity])
    print("grad_gates_cpu:", grad_gates_cpu)


if __name__ == '__main__':
    moe_dispatch_bwd_gate()
```
Problem: the CUDA result is inconsistent with the CPU result. CUDA: [0.4180, 0.0000], CPU: [0.3469, -0.3082].
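For reference, the CPU numbers can be reproduced by hand. The sketch below is a plain-Python model (the helper `backward_gate_reference` is hypothetical, not part of tutel) that assumes the gate gradient is the dot product of each sample's input row with the dispatched row at slot `indices[s] * capacity + locations[s]`. That formula is an assumption inferred from the CPU output above, not taken from tutel's source.

```python
# Hypothetical plain-Python model of the backward-gate computation.
# Assumed semantics (inferred from the CPU output, not from tutel's code):
#   grad_gates[s] = sum_h input[s * hidden + h] * dispatch[slot * hidden + h]
#   where slot = indices[s] * capacity + locations[s]
def backward_gate_reference(indices, locations, inp, dispatch, samples, hidden, capacity):
    grad_gates = [0.0] * samples
    for s in range(samples):
        slot = indices[s] * capacity + locations[s]  # row in the dispatched buffer
        grad_gates[s] = sum(
            inp[s * hidden + h] * dispatch[slot * hidden + h] for h in range(hidden)
        )
    return grad_gates


grads = backward_gate_reference(
    indices=[0, 0], locations=[0, 0],
    inp=[0.4946, -0.0043, 0.5386, -0.8354],
    dispatch=[0.7085, 0.8257, -0.1455, -0.1788],
    samples=2, hidden=2, capacity=2,
)
print([round(g, 4) for g in grads])  # [0.3469, -0.3082]
```

Both samples map to slot 0 here, so the second gradient reuses dispatch[0:2] against input[2:4].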
Analysis of the CUDA calculation:
When index=0, the kernel computes the gradient of the first gate.
Because dispatched_input and reshaped_input are of type half2 (two packed halves, the same 4-byte width as a float pointer element),
when i=0 the dispatched_input subscript index * (hidden) + i is 0 and the reshaped_input subscript index * (hidden) + i is 0, so each read fetches the first two half values, and the products are accumulated into grad_gates1_s_rf.
Values read: dispatch = [0.7085, 0.8257], input = [0.4946, -0.0043]
Result for i=0: grad_gates1_s_rf = 0.7085 * 0.4946 + 0.8257 * (-0.0043) = 0.34687359
When i=1, both subscripts index * (hidden) + i are 1, so each read fetches the last two half values, and these products are also added to the first gate's gradient.
Values read: dispatch = [-0.1455, -0.1788], input = [0.5386, -0.8354]
Result for i=1: grad_gates1_s_rf += 0.5386 * (-0.1455) + (-0.8354) * (-0.1788) = 0.07100322
Final grad_gates1_s_rf = 0.34687359 + 0.07100322 = 0.41787681
When index=1, the kernel computes the gradient of the second gate. The starting subscript into input is now 2 (in half2 units), which is past the end of the array. The value read from that illegal address may be 0, which would explain why the second gradient comes out as 0.
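The trace above can be replayed in a few lines of Python. This is a simulation, not the kernel itself: it assumes each loop step reads one half2 pair from each buffer and adds both products, and it uses float64 instead of fp16, so the first value is 0.41787681 rather than the fp16-rounded 0.4180. Here `hidden` is passed as the element count (2), exactly as in the repro.

```python
# Simulates the half2 access pattern traced above (float64 arithmetic, an
# approximation of fp16). `hidden` is the un-halved element count, as in the repro.
inp = [0.4946, -0.0043, 0.5386, -0.8354]
dispatch = [0.7085, 0.8257, -0.1455, -0.1788]

# Reinterpret the flat half arrays as (lo, hi) pairs, like a half2 pointer does.
inp2 = [tuple(inp[k:k + 2]) for k in range(0, len(inp), 2)]
dispatch2 = [tuple(dispatch[k:k + 2]) for k in range(0, len(dispatch), 2)]

samples, hidden, capacity = 2, 2, 2   # hidden NOT halved
indices, locations = [0, 0], [0, 0]

results = []
for index in range(samples):
    slot = indices[index] * capacity + locations[index]
    acc, out_of_bounds = 0.0, False
    for i in range(hidden):
        a, b = index * hidden + i, slot * hidden + i
        if a >= len(inp2) or b >= len(dispatch2):
            out_of_bounds = True  # on the GPU this read is past the end of the buffer
            break
        acc += inp2[a][0] * dispatch2[b][0] + inp2[a][1] * dispatch2[b][1]
    results.append(None if out_of_bounds else acc)

print(results)  # first entry is about 0.4179; second is None (out-of-bounds read)
```

Gate 0 sums both half2 pairs (0.34687359 + 0.07100322), and gate 1 starts at pair 2 of a 2-pair buffer, matching the analysis.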
Hi, thanks for the report. After tracing it, this turns out not to be a bug; your code calls the kernel incorrectly:
For torch.float16 inputs, the CUDA kernel evaluates on packed fp16x2 (half2) values, so the hidden_size value passed to it must be divided by 2 as well (see https://github.com/microsoft/tutel/blob/main/tutel/impls/fast_dispatch.py#L95).
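To see why halving matters, NumPy can reinterpret the same fp16 memory as 32-bit words, which is a rough analogy for how a half2 pointer views the buffer (the analogy is an assumption; the kernel's actual pointer casts are not shown in this thread). Each row of `hidden` halves becomes `hidden // 2` packed elements, so a row stride of `hidden` overshoots.

```python
import numpy as np

samples, hidden = 2, 2
inp = np.asarray([0.4946, -0.0043, 0.5386, -0.8354], dtype=np.float16).reshape(samples, hidden)

# View each pair of float16 values as one 32-bit word, like a half2 element.
packed = inp.view(np.uint32)
print(packed.shape)  # (2, 1): each row has hidden // 2 packed elements

# With the un-halved value, row 1 would start at 1 * hidden == 2 packed words,
# but the whole buffer only holds packed.size == 2 words (valid offsets 0 and 1).
bad_offset = 1 * hidden
print(bad_offset >= packed.size)  # True: past the end of the buffer
```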
In other words, you should change your code from:
```python
moe_dispatch_bwd_gate(grad_gates, indices_gpu, locations_gpu, input_gpu, dispatch_gpu, extra=[samples, hidden, capacity])
```
into:
```python
moe_dispatch_bwd_gate(grad_gates, indices_gpu, locations_gpu, input_gpu, dispatch_gpu, extra=[samples, hidden if input_gpu.dtype is not torch.float16 else hidden // 2, capacity])
```
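As a sanity check on the fix, the sketch below models the half2 loop with the halved argument. The function `half2_backward_gate` is hypothetical and only mirrors the access pattern described in this thread (one half2 pair per step, both products summed, float64 arithmetic); it is not tutel's kernel. With `hidden // 2`, its output lines up with the CPU result.

```python
# Hypothetical model of the half2 gate-gradient loop, called with hidden // 2.
def half2_backward_gate(inp, dispatch, indices, locations, samples, hidden_half2, capacity):
    # Group flat half values into (lo, hi) pairs, as a half2 pointer would.
    inp2 = list(zip(inp[0::2], inp[1::2]))
    dispatch2 = list(zip(dispatch[0::2], dispatch[1::2]))
    grads = []
    for index in range(samples):
        slot = indices[index] * capacity + locations[index]
        acc = 0.0
        for i in range(hidden_half2):
            x = inp2[index * hidden_half2 + i]
            y = dispatch2[slot * hidden_half2 + i]
            acc += x[0] * y[0] + x[1] * y[1]
        grads.append(acc)
    return grads


hidden = 2
grads = half2_backward_gate(
    [0.4946, -0.0043, 0.5386, -0.8354],
    [0.7085, 0.8257, -0.1455, -0.1788],
    indices=[0, 0], locations=[0, 0],
    samples=2, hidden_half2=hidden // 2, capacity=2,
)
print([round(g, 4) for g in grads])  # [0.3469, -0.3082], matching the CPU result
```

Every index now stays within the two half2 pairs of each buffer, so no read goes out of bounds.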