functorch
Swapping 2 columns in a 2d tensor
I have a function, tridiagonalization, that tridiagonalizes a matrix (a 2D tensor), and I want to map it over a batch. It contains a for loop, and each iteration swaps two rows and two columns. I can't figure out how to swap two columns without getting errors. My code for rows works and looks like this:
row_temp = matrix_stacked[pivot[None]][0]
matrix_stacked[[pivot[None]][0]] = matrix_stacked[i+1].clone()
matrix_stacked[i+1] = row_temp
Here pivot is a tensor and i is a Python integer. For columns I have something like this:
column_temp = matrix_stacked[:, [pivot[None]][0]]
matrix_stacked[:, [pivot[None]][0]] = matrix_stacked[:, [i+1]].clone()
matrix_stacked[:, i+1] = column_temp
It does not work because of a size mismatch. What should I do to swap columns i+1 and pivot?
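A likely cause of the size error, judging from the code above: indexing with the one-element tensor pivot[None] keeps a size-1 axis, so column_temp has shape (n, 1), while matrix_stacked[:, i+1] selects a plain column of shape (n,). An (n, 1) tensor cannot be broadcast into an (n,) slot. The smallest change to the original code is probably matrix_stacked[:, i+1] = column_temp.squeeze(1). The sketch below shows two alternatives; the helper names swap_cols_ and swap_cols are hypothetical. The in-place version swaps both columns with a single advanced-indexing assignment, so both sides have the same shape and no temporary column is needed (advanced indexing on the right-hand side already returns a copy). The out-of-place version builds a permutation index instead of writing in place. In-place writes can fail under vmap when the tensor being written to is not itself batched, so this form is assumed to be easier to map over a batch. It uses torch.func.vmap, the PyTorch 2.x successor of functorch.vmap.

```python
import torch
from torch.func import vmap  # PyTorch 2.x; older setups: from functorch import vmap


def swap_cols_(m: torch.Tensor, a, b) -> torch.Tensor:
    """Swap columns a and b of a 2D tensor in place (eager mode only).

    a and b may be Python ints or 0-d integer tensors.
    """
    idx = torch.tensor([int(a), int(b)], device=m.device)
    # Both sides have shape (n, 2), so no broadcasting is involved.
    # The right-hand side is a copy, so no temporary column is needed.
    m[:, idx] = m[:, idx.flip(0)]
    return m


def swap_cols(m: torch.Tensor, a, b) -> torch.Tensor:
    """Return a copy of m with columns a and b swapped (no in-place writes)."""
    idx = torch.arange(m.shape[-1], device=m.device)
    # perm[a] = b, perm[b] = a, all other entries keep their own index.
    # If a == b, the two correction terms cancel and perm == idx.
    perm = idx + (idx == a) * (b - a) + (idx == b) * (a - b)
    return m[:, perm]


if __name__ == "__main__":
    m = torch.arange(9.0).reshape(3, 3)
    pivot = torch.tensor(0)

    # The original mismatch: a size-1 axis is kept by the tensor index.
    print(m[:, pivot[None]].shape, m[:, 1].shape)

    print(swap_cols(m, pivot, 2))

    # Batched: one pivot per matrix, the other column index shared.
    batch = torch.arange(18.0).reshape(2, 3, 3)
    pivots = torch.tensor([0, 1])
    print(vmap(swap_cols, in_dims=(0, 0, None))(batch, pivots, 2))
```

For rows, the same in-place idea is m[idx] = m[idx.flip(0)], and the out-of-place version indexes m[perm] instead of m[:, perm].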