functorch

Output with shape [shape_1, shape_2] does not match the broadcast shape [batch, shape_1, shape_2]

Open • xmser opened this issue on Jul 10, 2022 • 1 comment

I want to use vmap to compute matrix inverses in a batch with the Neumann series (an iterative method): $\text{matrix}^{-1} \approx \sum_{k=0}^{K-1} (I - \text{matrix})^{k}$, which becomes exact as $K \to \infty$. But when I use the following code on a batch of matrices of shape [20, 640, 640] to get their inverses, I get the error shown in the title.
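The series above can be checked on a single matrix outside vmap. The sketch below is illustrative: the function name `neumann_inverse` and the test matrix are not from the original code. It accumulates powers of $I - A$ out of place and compares the result with `torch.linalg.inv`. The truncated series only converges when the spectral radius of $I - A$ is below 1, so the example deliberately builds $A = I + P$ with $\lVert P \rVert_2 = 0.5$.

```python
import torch


def neumann_inverse(a: torch.Tensor, terms: int = 100) -> torch.Tensor:
    """Approximate a^{-1} by sum_{k=0}^{terms-1} (I - a)^k.

    Hypothetical helper for illustration; only converges when the
    spectral radius of (I - a) is below 1.
    """
    eye = torch.eye(a.shape[-1], dtype=a.dtype, device=a.device)
    multiplier = eye - a
    term = eye.clone()              # (I - a)^0
    result = torch.zeros_like(eye)
    for _ in range(terms):
        result = result + term      # out-of-place accumulation of the current power
        term = term @ multiplier    # advance to the next power of (I - a)
    return result


torch.manual_seed(0)
n = 8
# Scale a random perturbation so that ||I - a||_2 = 0.5, which guarantees convergence.
p = torch.randn(n, n, dtype=torch.float64)
p = 0.5 * p / torch.linalg.matrix_norm(p, ord=2)
a = torch.eye(n, dtype=torch.float64) + p

approx = neumann_inverse(a, terms=100)
exact = torch.linalg.inv(a)

# Because every update is out of place and `@` broadcasts, the same function
# also accepts a stacked [batch, n, n] tensor without vmap.
stacked = torch.stack([a, a.T])
approx_stacked = neumann_inverse(stacked, terms=100)
```

For matrices that do not satisfy the convergence condition, the partial sums grow instead of settling, so 100 terms will not give a usable inverse regardless of how the batching is done.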

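This is my code:

```python
import math

import torch
from functorch import vmap


def Neumann(matrix, times=100):
    # Truncated Neumann series: sum of (I - matrix)^k for k = 0 .. times - 1
    ans = torch.eye(matrix.shape[0]).cuda()
    multipler = ans - matrix                   # I - matrix
    temp = torch.eye(matrix.shape[0]).cuda()   # current power, starts at (I - matrix)^0 = I
    result = torch.zeros_like(temp)            # running sum (same device as temp)
    for steps in range(times):
        result += temp                         # add the current power in place
        temp = temp.mm(multipler)              # advance to the next power
    del ans
    del multipler
    del temp
    return result


def get_single_inverse(matrix):
    return Neumann(matrix=matrix)


# partial_xx_tensor (all matrices) and mini_batch (chunk size) are defined elsewhere.
batch_get_inverse = vmap(get_single_inverse)
total_lenth = partial_xx_tensor.shape[0]
batch_numbers = math.ceil(total_lenth / mini_batch)
partial_xx_inv_list = [
    batch_get_inverse(partial_xx_tensor[dex * mini_batch : min((dex + 1) * mini_batch, total_lenth)])
    for dex in range(batch_numbers)
]
```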
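A likely explanation, not a confirmed diagnosis from this thread: inside `Neumann`, `result` is created with `torch.zeros_like(temp)` from a freshly built identity, so it is an ordinary (unbatched) tensor, while `temp` becomes batched after the first `temp.mm(multipler)`. On the second iteration, `result += temp` asks an unbatched [640, 640] tensor to hold a [batch, 640, 640] value in place, which matches the shape mismatch in the title. The sketch below replaces the in-place add with an out-of-place one. It assumes PyTorch 2.0 or later, where `vmap` is available as `torch.func.vmap` (with the standalone functorch package, `from functorch import vmap` plays the same role); the name `neumann` and the small test sizes are illustrative.

```python
import torch
from torch.func import vmap  # assumption: PyTorch >= 2.0


def neumann(matrix: torch.Tensor, times: int = 100) -> torch.Tensor:
    # Build the identity on the input's device/dtype instead of hard-coding .cuda().
    eye = torch.eye(matrix.shape[0], dtype=matrix.dtype, device=matrix.device)
    multiplier = eye - matrix
    temp = eye
    result = torch.zeros_like(eye)
    for _ in range(times):
        # Out-of-place add: `result` is allowed to become batched once `temp` is,
        # whereas `result += temp` would write a batched value into an unbatched tensor.
        result = result + temp
        temp = temp.mm(multiplier)
    return result


batch_get_inverse = vmap(neumann)

# Small illustrative batch whose members all satisfy ||I - A||_2 = 0.5.
torch.manual_seed(0)
n, batch = 6, 4
p = torch.randn(batch, n, n, dtype=torch.float64)
p = 0.5 * p / torch.linalg.matrix_norm(p, ord=2, keepdim=True)
mats = torch.eye(n, dtype=torch.float64) + p

inv = batch_get_inverse(mats)  # shape [batch, n, n]
```

Taking the device from the input keeps the function usable on CPU and GPU alike; with the original `.cuda()` calls the identity would always be placed on the GPU.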

I want to use `batch_get_inverse` to compute the inverses in batches. `partial_xx_tensor` has shape [2000, 640, 640], so I split it into mini-batches such as [20, 640, 640] to get the inverses.
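The manual slicing in the list comprehension can be compared with `torch.split`, which cuts a tensor into chunks of a given size along dimension 0 and leaves a smaller final chunk when the length is not a multiple of the chunk size. The sizes below (45 matrices of size 2×2, chunks of 20) are illustrative stand-ins for the 2000 matrices in the issue.

```python
import math

import torch

total, mini_batch = 45, 20  # illustrative sizes, not the ones from the issue
x = torch.arange(total * 2 * 2, dtype=torch.float64).reshape(total, 2, 2)

# Manual slicing, as in the original list comprehension.
num_batches = math.ceil(total / mini_batch)
manual = [x[i * mini_batch : min((i + 1) * mini_batch, total)] for i in range(num_batches)]

# torch.split yields the same chunks along dim 0; the last one may be shorter.
chunks = torch.split(x, mini_batch)
print([c.shape[0] for c in chunks])  # [20, 20, 5]

# Per-chunk results can be reassembled into one tensor with torch.cat.
restored = torch.cat(chunks)
```

Python slicing already clamps at the end of the tensor, so the `min(...)` in the original comprehension is not strictly necessary; `x[40:60]` simply returns the last 5 matrices.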

Thank you very much!

xmser • Jul 10 '22 14:07

@xmser Thanks for the issue. I'm trying to run the repro above: what is the value of `mini_batch`, and what is the shape of `partial_xx_tensor`?

zou3519 • Jul 11 '22 15:07