Assigning values to a Tensor (translating a torch expression to Candle)
Hi, how should I translate this expression from Python (torch) to Candle? count, mask and result are all tensors:
mask = (count.squeeze(-1) > 0)
result[mask] = result[mask] / count[mask].repeat(1, C)
I tried it this way, but it seems wrong:
let mask = count.squeeze(D::Minus1)?.gt(0 as i64)?;
let masked_ans_result = ans_result.i(&mask)?;
let repeated_ans_count = ans_count.i(&mask)?.repeat(&[c])?;
let updated_ans_result = masked_ans_result.div(&repeated_ans_count)?;
ans_result.i(&mask)?.eq( &updated_ans_result)?;
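For reference, here is what the torch snippet computes, written as a plain-Rust loop over flat row-major buffers, with no Candle involved. It assumes result is [H, W, C] and count is [H, W, 1] (the shapes given further down the thread), and the helper name masked_normalize is hypothetical. Note also that the last line of the attempt above calls eq, which only compares two tensors and returns a new one; it does not write anything back into ans_result.

```rust
/// Plain-Rust reference for the torch snippet (hypothetical helper name):
///     mask = (count.squeeze(-1) > 0)
///     result[mask] = result[mask] / count[mask].repeat(1, C)
/// `result` is a row-major [H, W, C] buffer, `count` a row-major [H, W, 1] buffer.
fn masked_normalize(result: &mut [f32], count: &[f32], h: usize, w: usize, c: usize) {
    assert_eq!(result.len(), h * w * c, "result must be [H, W, C]");
    assert_eq!(count.len(), h * w, "count must be [H, W, 1]");
    for pixel in 0..h * w {
        let n = count[pixel];
        // mask[pixel]: only pixels with a strictly positive count are updated
        if n > 0.0 {
            // repeat(1, C) just copies the single count value across all C channels
            for ch in 0..c {
                result[pixel * c + ch] /= n;
            }
        }
        // pixels outside the mask keep their original values
    }
}
```

For example, with H = 1, W = 2, C = 3, a first pixel [2, 4, 6] with count 2 becomes [1, 2, 3], while a second pixel with count 0 is left untouched.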
Hi @wiktorkujawa! Let's break this down into 2 steps:
- result[mask] / count[mask].repeat(1, C)
  - Use the Candle Tensor::gather method to extract values along a certain dimension, perhaps D::Minus1 here?
let new_result = (result.gather(&mask, D::Minus1)? / count.gather(&mask)?.repeat(&[1, c])?)?;
- Assign back into result
This is a bit convoluted, so let's look at our tool:
  - Tensor::scatter_add
The idea is that we want to zero out the part of result where we want to insert the new values, and then add them. Until Candle gets a dedicated op for this, this is the best we can do, and it should be reasonably fast on CUDA/Metal.
let result_zeroed = result.scatter_add(&mask, result.neg()?)?;
let result = result_zeroed.scatter_add(&mask, new_result)?;
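The gather / "zero out, then add" idea from the two steps above can be checked on plain buffers. Candle's gather and scatter_add work on integer index tensors along a dimension rather than on a boolean mask, so this sketch first turns the mask into a list of pixel indices. It is a std-only illustration on a row-major [H*W, C] buffer, not Candle code, and every helper name in it is hypothetical.

```rust
/// Hypothetical helper: indices of the pixels where count > 0 (the boolean mask as indices).
fn mask_to_indices(count: &[f32]) -> Vec<usize> {
    (0..count.len()).filter(|&i| count[i] > 0.0).collect()
}

/// Gather the C-channel rows of `src` ([H*W, C], row-major) at `idx` into an [N, C] buffer.
fn gather_rows(src: &[f32], idx: &[usize], c: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(idx.len() * c);
    for &i in idx {
        out.extend_from_slice(&src[i * c..(i + 1) * c]);
    }
    out
}

/// Scatter-add the [N, C] rows of `src` into `dst` ([H*W, C]) at the row positions `idx`.
fn scatter_add_rows(dst: &mut [f32], idx: &[usize], src: &[f32], c: usize) {
    for (row, &i) in idx.iter().enumerate() {
        for ch in 0..c {
            dst[i * c + ch] += src[row * c + ch];
        }
    }
}

/// Step 1 (gather + divide) and step 2 (zero out via scatter-add of -old, then scatter-add new).
fn normalize_via_scatter(result: &mut [f32], count: &[f32], c: usize) {
    let idx = mask_to_indices(count);
    // Step 1: result[mask] / count[mask].repeat(1, C)
    let old = gather_rows(result, &idx, c);
    let new: Vec<f32> = old
        .iter()
        .enumerate()
        .map(|(k, v)| v / count[idx[k / c]])
        .collect();
    // Step 2a: x + (-x) == 0.0 exactly for finite x, so the masked rows are cleared
    let neg_old: Vec<f32> = old.iter().map(|v| -v).collect();
    scatter_add_rows(result, &idx, &neg_old, c);
    // Step 2b: add the new values into the cleared rows
    scatter_add_rows(result, &idx, &new, c);
}
```

The zeroing trick only cancels exactly for finite values; a NaN or infinity already stored in result would survive as NaN.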
So, the final code might look like this:
let new_result = (result.gather(&mask, D::Minus1)? / count.gather(&mask)?.repeat(&[1, c])?)?;
let result_zeroed = result.scatter_add(&mask, result.neg()?)?;
let result = result_zeroed.scatter_add(&mask, new_result)?;
@EricLBuehler What about dimensions? Both gather and scatter_add always require a dimension value (gather needs two arguments, and scatter_add needs three). It seems this Dim argument is not optional.
@wiktorkujawa what are the shapes of result, mask, and count?
@EricLBuehler result and count look something like this:
result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype) # [H, W, C]
count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype) # [H, W, 1]
mask = (count.squeeze(-1) > 0)
where:
- H and W are typically 256 or 512 (it's a texture size, like the size of a PNG or JPG image)
- C = values.shape[-1], where values is a tensor produced by some RGB transformations (it is used to rasterize visible Gaussians into an image, so there are quite a lot of transformations along the way)
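With these shapes, count's trailing dimension of size 1 broadcasts over the C channels, so the update can also be written without gather or scatter_add as an elementwise select: out[h, w, c] = result[h, w, c] / count[h, w, 0] where count[h, w, 0] > 0, and result[h, w, c] otherwise. This is a swapped-in technique, not the approach proposed above. The sketch below shows it on flat row-major buffers; the name normalize_via_select is hypothetical. In Candle this would presumably map onto comparing count with zero (gt), broadcasting the mask and count to result's shape (broadcast_as), dividing, and choosing with where_cond; treat that mapping as an untested assumption.

```rust
/// Elementwise select: out = if count > 0 { result / count } else { result },
/// with the [H, W, 1] count broadcast across the C channels of the [H, W, C] result.
fn normalize_via_select(result: &[f32], count: &[f32], c: usize) -> Vec<f32> {
    result
        .iter()
        .enumerate()
        .map(|(k, &v)| {
            // broadcast: every channel of pixel k / c sees the same count value
            let n = count[k / c];
            // where(mask, result / count, result)
            if n > 0.0 { v / n } else { v }
        })
        .collect()
}
```

Here the division only happens where the mask holds. A tensor version that divides everywhere first would produce NaN or infinity for pixels with count 0 and rely on the select to discard those values.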