
0-Dimension tensor manipulation

Open gcesars opened this issue 1 year ago • 4 comments

I am unsure if this is already supported, but I could not find examples or clear ways to achieve simple tensor manipulations.

  1. Create and convert 0-dimension tensors (aka scalars). I tried D = 0, and it panicked.
  2. Select individual 0-dimension tensors from 1-dimension ones, e.g. a[i] or a.get(i).
  3. Quick and easy conversion between 0-dimension tensors and primitives, e.g. let t = Tensor<B, 0>::from(0.5); let f: f64 = t.into(); if t > 0.9 { do_stuff() }
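To make item 3 concrete, here is a plain-Rust sketch of the requested ergonomics. It uses a hypothetical Scalar wrapper (not a Burn type) standing in for a 0-dimension tensor. Implementing From<f64>, From<Scalar> for f64, and PartialOrd<f64> is enough for Scalar::from(0.5), t.into() and t > 0.9 to compile. How Burn would back such a type with device data is left open.

```rust
use std::cmp::Ordering;

/// Hypothetical stand-in for a 0-dimension tensor holding one f64.
/// Not part of Burn; it only shows which std traits give the requested syntax.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Scalar(f64);

// Enables `Scalar::from(0.5)`.
impl From<f64> for Scalar {
    fn from(v: f64) -> Self {
        Scalar(v)
    }
}

// Enables `let f: f64 = t.into();`.
impl From<Scalar> for f64 {
    fn from(s: Scalar) -> Self {
        s.0
    }
}

// PartialOrd<f64> requires PartialEq<f64>.
impl PartialEq<f64> for Scalar {
    fn eq(&self, other: &f64) -> bool {
        self.0 == *other
    }
}

// Enables `t > 0.9` directly against a primitive.
impl PartialOrd<f64> for Scalar {
    fn partial_cmp(&self, other: &f64) -> Option<Ordering> {
        self.0.partial_cmp(other)
    }
}

fn main() {
    let t = Scalar::from(0.5);
    let f: f64 = t.into(); // Scalar is Copy, so t is still usable below
    println!("f = {f}");
    if t > 0.9 {
        println!("above threshold");
    } else {
        println!("below threshold");
    }
}
```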

I'm asking this because when you work with more complex algorithms, it is common to have scalars, taken as individual values from computed tensors, that drive the control flow.

A simple dummy example that I could not achieve in a reasonable way:

let w = ...  // large 1-dimension tensor, most elements near 0
let mask = w.greater_equal_elem(0.01)
let idx = ???  // can't find an operation that returns the indices where mask == true (or gets them directly from w, skipping mask)
idx.iter().map/fold(do_stuff())
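The missing piece here is an "indices where true" operation (what NumPy calls argwhere/nonzero). As a plain-Rust sketch of the same selection, assuming the tensor's values have already been copied out into a Vec<f32>, it is an enumerate/filter. How to copy the data out depends on the Burn version; later in this thread .to_data().value is used. The function name indices_at_least is hypothetical.

```rust
/// Hypothetical helper: indices of `values` whose element is >= `threshold`.
/// `values` is assumed to be a host-side copy of the tensor data.
fn indices_at_least(values: &[f32], threshold: f32) -> Vec<usize> {
    values
        .iter()
        .enumerate()
        .filter(|&(_, &v)| v >= threshold)
        .map(|(i, _)| i)
        .collect()
}

fn main() {
    // Mostly near-zero weights, as in the dummy example.
    let w: Vec<f32> = vec![0.0, 0.002, 0.5, 0.0001, 0.03, 0.009];
    let idx = indices_at_least(&w, 0.01);
    // Fold over the selected indices, here summing the selected weights.
    let total: f32 = idx.iter().fold(0.0, |acc, &i| acc + w[i]);
    println!("idx = {:?}, total = {}", idx, total);
}
```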

This is one dummy example, but in general, manipulating tensors with Burn is hard.

For reference, the i(..) indexing method in tch is extremely convenient: https://docs.rs/tch/latest/tch/struct.Tensor.html#method.i-1. How can users do this kind of manipulation in Burn?

gcesars, Jan 07 '24

Scalars can be represented as rank-1 tensors with shape [1] (Tensor<B, 1>); otherwise, any native Rust numeric type (f32, i32, etc.) will also work, since they all implement the Element trait.

nathanielsimard, Jan 08 '24

Here is a snippet (an Armijo stopping condition) showing something I found very unclear how to achieve:

use burn::backend::{Autodiff, NdArray};
use burn::tensor::backend::Backend;
use burn::tensor::{Float, Tensor};

type Bad = Autodiff<NdArray>;

fn obj<B: Backend>(x: Tensor<B, 1, Float>) -> Tensor<B, 1, Float> {
    (x.clone() * x).sum()
}

fn main() {
    let t = Tensor::<Bad, 1, Float>::from_data(vec![1.0, 2.0, 3.0, 4.0].as_slice())
        .set_require_grad(true);

    let mut grads = obj(t.clone()).backward();
    let g_t = t.grad(&grads).unwrap();
    t.grad_remove(&mut grads);

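    // Armijo backtracking: c is the sufficient-decrease constant and
    // m = ||g||^2 is the squared norm of the gradient.
    let c = 0.001;
    let mut alpha = 1.0;
    let m = (g_t.clone() * g_t.clone()).sum();

    // From here on, work on the inner (non-autodiff) backend.
    let obj_val = obj(t.clone().inner());
    let mut x = t.clone().inner() + (g_t.clone() * (-alpha));

    // Keep halving alpha while f(x - alpha*g) > f(x) - c*alpha*||g||^2.
    // The comparison yields a Bool tensor of shape [1]; its single value
    // is pulled out with .to_data().value[0] to get a plain bool.
    let mut condition = obj(x.clone())
        .greater(obj_val.clone() + m.clone() * (-alpha * c))
        .to_data()
        .value[0];

    while condition {
        alpha *= 0.5;
        x = t.clone().inner() + (g_t.clone() * (-alpha));
        condition = obj(x.clone())
            .greater(obj_val.clone() + m.clone() * (-alpha * c))
            .to_data()
            .value[0];
    }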

    println!("alpha: {}", alpha);
    println!("x: {:?}", x);
    println!("obj(x): {}", obj(x));
}
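For comparison, here is the same Armijo backtracking logic written in plain Rust on slices of f64. It uses no Burn and substitutes the analytic gradient 2x for autodiff. It shows the shape the control flow wants: each stopping check is an ordinary bool from comparing two f64 values. This is only a reference sketch, and the function names are hypothetical. A Burn version still needs some way to pull a single value out of the result tensor.

```rust
/// Objective from the snippet above: f(x) = sum_i x_i^2.
fn obj(x: &[f64]) -> f64 {
    x.iter().map(|v| v * v).sum()
}

/// Analytic gradient of `obj` (used here instead of autodiff): 2 * x.
fn grad(x: &[f64]) -> Vec<f64> {
    x.iter().map(|v| 2.0 * v).collect()
}

/// Backtracking line search along -g: halve alpha until the Armijo
/// sufficient-decrease condition f(x - alpha*g) <= f(x) - c*alpha*||g||^2 holds.
fn armijo_step(x: &[f64], c: f64, mut alpha: f64) -> (f64, Vec<f64>) {
    let g = grad(x);
    let m: f64 = g.iter().map(|v| v * v).sum(); // ||g||^2
    let f0 = obj(x);
    // Candidate point x - a * g for a given step size a.
    let step = |a: f64| -> Vec<f64> { x.iter().zip(&g).map(|(xi, gi)| xi - a * gi).collect() };

    let mut x_new = step(alpha);
    // Plain f64 comparison: the scalar drives the loop directly.
    while obj(&x_new) > f0 - c * alpha * m {
        alpha *= 0.5;
        x_new = step(alpha);
    }
    (alpha, x_new)
}

fn main() {
    let x = [1.0, 2.0, 3.0, 4.0];
    let (alpha, x_new) = armijo_step(&x, 0.001, 1.0);
    println!("alpha: {}", alpha);
    println!("x: {:?}", x_new);
    println!("obj(x): {}", obj(&x_new));
}
```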

Is this the expected way to work with scalars? If so, I recommend adding a section on tensor manipulation in general to the book; it took me a lot of guesswork and trial and error to get such a simple operation working.

Replacing .to_data().value[0] with into_scalar() doesn't compile:

error[E0599]: the method `into_scalar` exists for struct `Tensor<NdArray, 1, Bool>`, but its trait bounds were not satisfied

gcesars, Jan 10 '24

into_scalar is only available for the numeric types (Float and Int tensors), but I guess we should have something similar for Bool tensors. Maybe have an into_elem() for all tensor types instead?

nathanielsimard, Jan 16 '24

As a user, I would find different methods for each tensor kind confusing. Although it seems a bit janky, Candle has:

to_scalar()
to_vec1()
to_vec2()
to_vec3()

from_vec()
from_slice()
from_iter()

Those are extremely convenient when writing applications or algorithms that need to interoperate with other libraries. Also, from/to_ndarray would be amazing, since ndarray seems to be the NumPy of Rust.

gcesars, Jan 16 '24

I have implemented into_scalar for Bool tensors as well. I don't remember which PR it was part of, but it's currently on main.

antimora, Mar 29 '24