
Assigning value to Tensor (translate torch expression to Candle)

Open wiktorkujawa opened this issue 1 year ago • 4 comments

Hi, how should I translate this expression from Python (torch) to Candle? count, mask, and result are all tensors:

mask = (count.squeeze(-1) > 0)
result[mask] = result[mask] / count[mask].repeat(1, C)
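To pin down what these two PyTorch lines do: indexing the [H, W, C] result with an [H, W] boolean mask selects whole pixels (rows of C values), and count[mask].repeat(1, C) copies each selected pixel's single count across its C channels. Unselected pixels (count == 0) are left untouched. The plain-Rust reference below (no Candle; `masked_divide_reference` is a hypothetical name, and the row-major [H, W, C] / [H, W, 1] layout is assumed from the shapes posted later in the thread) is only a sketch for checking a Candle port against:

```rust
/// Reference semantics of
///   mask = count.squeeze(-1) > 0
///   result[mask] = result[mask] / count[mask].repeat(1, C)
/// `result` is a row-major [H, W, C] buffer, `count` is a row-major [H, W, 1] buffer.
fn masked_divide_reference(result: &mut [f32], count: &[f32], c: usize) {
    assert_eq!(result.len(), count.len() * c, "result must hold C values per count");
    // Walk the pixels: each chunk of C values pairs with one count value
    for (pixel, &n) in result.chunks_mut(c).zip(count.iter()) {
        // Pixels with count <= 0 are not selected by the mask and stay unchanged
        if n > 0.0 {
            // repeat(1, C) just reuses the same count for every channel
            for v in pixel.iter_mut() {
                *v /= n;
            }
        }
    }
}
```

With two pixels of two channels, `[2, 4, 3, 6]` and counts `[2, 0]`, the first pixel becomes `[1, 2]` and the second stays `[3, 6]`.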

I tried it this way, but it seems wrong:

let mask = count.squeeze(D::Minus1)?.gt(0 as i64)?;
let masked_ans_result = ans_result.i(&mask)?;
let repeated_ans_count = ans_count.i(&mask)?.repeat(&[c])?;
let updated_ans_result = masked_ans_result.div(&repeated_ans_count)?;        
ans_result.i(&mask)?.eq( &updated_ans_result)?;

wiktorkujawa (Jul 19 '24 22:07)

Hi @wiktorkujawa! Let's break this down into 2 steps:

  1. result[mask] / count[mask].repeat(1, C)
  • Use the Candle Tensor::gather method to extract values along a given dimension, perhaps D::Minus1 here?
let new_result = (result.gather(&mask, D::Minus1)? / count.gather(&mask)?.repeat(&[1, c])?)?;
  2. Assign back into result. This is a bit convoluted, so let's look at the tool we have:
  • Tensor::scatter_add: the idea is that we zero out the part of result where we want to insert the new values, and then add the new values. Until Candle gets a dedicated op for this, this is the best we can do, and it should be reasonably fast on CUDA/Metal.
let result_zeroed = result.scatter_add(&mask, result.neg()?)?;
let result = result.scatter_add(&mask, new_result)?;

So, the final code might look like this:

let new_result = (result.gather(&mask, D::Minus1)? / count.gather(&mask)?.repeat(&[1, c])?)?;
let result_zeroed = result.scatter_add(&mask, result.neg()?)?;
let result = result.scatter_add(&mask, new_result)?;

EricLBuehler (Jul 19 '24 23:07)

@EricLBuehler What about dimensions? Both gather and scatter_add always require a dimension value (gather needs two arguments, and scatter_add needs three). It seems this Dim argument is not optional.

wiktorkujawa (Jul 20 '24 09:07)

@wiktorkujawa what are the shapes of result, mask, and count?

EricLBuehler (Jul 20 '24 15:07)

@EricLBuehler Result and count are something like this:

result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype)  # [H, W, C]
count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype)  # [H, W, 1]
mask = (count.squeeze(-1) > 0)

where:

  • H and W are typically 256 or 512 (it's a texture size, like the size of a PNG or JPG image)
  • C = values.shape[-1], where values is a tensor produced by some RGB transformations (it rasterizes visible Gaussians to an image, so there are quite a lot of transformations along the way)

wiktorkujawa (Jul 20 '24 18:07)