pytorch_sparse
OOM while using index_select on a huge sparse tensor
Hi,
I am trying to look up values at given indices (along the first dimension). However, my sparse tensor is quite large, and because of this, old_colptr, row, value = src.csc() runs into an OOM issue.
The implementation in torch itself is too slow to use, so I am wondering if you have a better way to solve this.
My Code:
# index is a tensor of shape [2, q_size * k_size]
# q_size and k_size would be around 27000
# since I want to get the value at position (i, j) but there is no such method,
# I reshape my sparse tensor to [1, q_size * k_size] and then select indices on dim=1
# the stride for i is k_size (row-major flattening of a [q_size, k_size] tensor)
index_offset = torch.tensor([k_size, 1], device=device)
flatten_index = (index.t() * index_offset).t().sum(0)  # (i, j) -> i * k_size + j
result = torch_sparse.index_select(my_sparse_tensor, 1, flatten_index)
Thank you in advance.
Thanks for this issue. Can you briefly explain your code above? I have a hard time understanding this. It looks like index is a (fully-connected) vector that gets transposed later on.
Nonetheless, efficient indexing in the first dimension requires CSC representation (which will require some memory to compute), so I am not yet sure where we can effectively save memory here.
Since a sparse tensor doesn't support sparse_tensor[sequence], I had to reshape both my sparse tensor and my index to match a [1, q_size * k_size] layout, and use index_select to gather all the values along dim=1.
So I first convert the 2D index to 1D; the corresponding positions become 1D, and so does the sparse tensor.
In sum, the sparse tensor has shape [1, q_size * k_size], the index has shape [num_indices], and I want to get the values at these positions.
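For concreteness, here is a tiny sketch of that flattening (toy sizes and made-up index values, assuming a row-major [q_size, k_size] layout):

import torch

# Toy sizes for illustration; the real q_size and k_size are around 27000.
q_size, k_size = 3, 4
# One (i, j) coordinate pair per column of index.
index = torch.tensor([[0, 1, 2],   # i coordinates
                      [3, 0, 2]])  # j coordinates
flat = index[0] * k_size + index[1]  # (i, j) -> i * k_size + j
print(flat)  # tensor([ 3,  4, 10])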
After some code tracing, I've found that the issue comes from colptr = self._col.new_zeros(self._sparse_sizes[1] + 1) inside csc(). It constructs an extremely large zero vector. So I am wondering if we can bypass this issue?
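As a back-of-the-envelope check with the sizes mentioned above (~27000 each), this single allocation is already several gigabytes:

n_cols = 27000 * 27000            # q_size * k_size columns after flattening
colptr_bytes = (n_cols + 1) * 8   # one int64 entry per column, plus one
print(colptr_bytes / 1e9)         # ~5.8 GB just for colptr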
Thank you!
Mmh, yes, the CSC/CSR representations will create a tensor based on the sparse tensor sizes. One alternative is to leverage torch.searchsorted on the original col representation. This way, there is no need to create this huge CSC representation.
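A minimal sketch of that idea (assuming my_sparse_tensor has shape [1, q_size * k_size] so that col from .coo() is sorted ascending, that it actually stores a value tensor, and that unstored positions should read as zero; flatten_index is the 1D index from above):

import torch

# col is sorted for a single-row tensor, so searchsorted applies directly.
row, col, value = my_sparse_tensor.coo()

# For each query position, find its insertion point in the sorted col vector.
pos = torch.searchsorted(col, flatten_index)

# Clamp to stay in bounds, then keep only exact matches (stored entries).
pos = pos.clamp(max=col.numel() - 1)
hit = col[pos] == flatten_index

# Unstored positions read as zero.
result = torch.zeros(flatten_index.numel(), dtype=value.dtype, device=value.device)
result[hit] = value[pos[hit]]

This avoids materializing colptr entirely; the cost is one binary search over the nnz stored entries per query instead of a dense O(q_size * k_size) pointer array.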
OK, I got your point. I will give it a try. Thank you.
This issue had no activity for 6 months. It will be closed in 2 weeks unless there is some new activity. Is this issue already resolved?