Unclear how to correctly compute dot product

Open • daniil-berg opened this issue 2 years ago • 1 comment

I am very new to Rust in general and to the juice/coaster crates in particular. Undoubtedly the following issue is due to my poor understanding, which is why I am not marking this as a bug.

Leaning heavily on the usage example provided here, I am trying to compute the dot product of two tensors (a 2x2 matrix and a 2x1 vector, to be precise).
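Strictly speaking, a 2x2 matrix times a 2x1 vector is a matrix-vector product, which is one dot product per row of the matrix, while a dot product of two vectors yields a single scalar. A short worked comparison using the issue's own data (the notation is an illustrative choice, not taken from the coaster docs):

```latex
% Dot product of two length-n vectors: a single scalar
a \cdot b = \sum_{i=1}^{n} a_i b_i
% Matrix-vector product: the i-th entry is row i of W dotted with x
(Wx)_i = \sum_{j=1}^{n} W_{ij} x_j
\qquad
Wx = \begin{pmatrix} 1 & 1 \\ 3 & -1 \end{pmatrix}
     \begin{pmatrix} 1 \\ 2 \end{pmatrix}
   = \begin{pmatrix} 1 \cdot 1 + 1 \cdot 2 \\ 3 \cdot 1 + (-1) \cdot 2 \end{pmatrix}
   = \begin{pmatrix} 3 \\ 1 \end{pmatrix}
```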

This is my output:

```
Data:
x = [1.0, 2.0]
w = [[1.0, 1.0],
     [3.0, -1.0]]
Expected result:
w*x = [3.0, 1.0]
Actual result:
w*x = [3.0, 0.0]
```
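One plausible reading of this output, assuming coaster-blas's `Dot` follows BLAS `dot`/`sdot` semantics (a scalar inner product of two vectors): only as many elements of the flattened `w` as `x` holds take part, giving 1·1 + 1·2 = 3.0, and that single scalar lands in the first slot of `wx` while the second slot is never written. The sketch below models this in plain Rust. `blas_style_dot` is a hypothetical stand-in, not the coaster API, and the `0.0` in `wx[1]` assumes the result buffer happens to start zeroed.

```rust
/// Hypothetical model of a BLAS-style `dot` (like `sdot`): the scalar inner
/// product of the first `n` elements of two flat buffers.
fn blas_style_dot(n: usize, x: &[f32], y: &[f32]) -> f32 {
    x[..n].iter().zip(&y[..n]).map(|(a, b)| a * b).sum()
}

fn main() {
    // Flat buffers exactly as written by `write_to_memory` in the issue.
    let w: [f32; 4] = [1.0, 1.0, 3.0, -1.0];
    let x: [f32; 2] = [1.0, 2.0];
    // The 2x1 result buffer, assumed to start out zeroed.
    let mut wx = [0.0f32; 2];

    // Assumption: only x.len() elements take part, and the one scalar
    // result is stored in the first slot of `wx`.
    wx[0] = blas_style_dot(x.len(), &w, &x);

    println!("w*x = {:?}", wx); // prints "w*x = [3.0, 0.0]"
}
```

If this reading is right, the reported `[3.0, 0.0]` is exactly what a vector dot product would produce here, so the call itself succeeds but computes a different operation than intended.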

Here is the code:

```rust
use coaster::backend::Backend;
use coaster::IFramework;
use coaster::frameworks::cuda::{Cuda, get_cuda_backend};
use coaster::frameworks::native::{Cpu, Native};
use coaster::frameworks::native::flatbox::FlatBox;
use coaster::tensor::SharedTensor;
use coaster_blas::plugin::Dot;


fn write_to_memory<T: Copy>(mem: &mut FlatBox, data: &[T]) {
    let mem_buffer: &mut[T] = mem.as_mut_slice::<T>();
    for (index, datum) in data.iter().enumerate() {
        mem_buffer[index] = *datum;
    }
}


pub fn main() {
    let backend: Backend<Cuda> = get_cuda_backend();
    let native: Native = Native::new();
    let cpu: Cpu = native.new_device(native.hardwares()).unwrap();

    let mut x: SharedTensor<f32> = SharedTensor::<f32>::new(&(2, 1));
    let mut w: SharedTensor<f32> = SharedTensor::<f32>::new(&(2, 2));
    let mut wx: SharedTensor<f32> = SharedTensor::<f32>::new(&(2, 1));

    let x_values: &Vec<f32> = &vec![1f32, 2.0];
    let w_values: &Vec<f32> = &vec![1f32,  1.0,
                                     3.0, -1.0];

    println!("Data:");
    println!("x = {:?}", x_values);
    println!("w = [{:?},\n     {:?}]", &w_values[..2], &w_values[2..]);
    println!("Expected result:");
    println!("w*x = [3.0, 1.0]");

    write_to_memory(x.write_only(&cpu).unwrap(), x_values);
    write_to_memory(w.write_only(&cpu).unwrap(), w_values);

    backend.dot(&w, &x, &mut wx).unwrap();
    println!("Actual result:");
    println!("w*x = {:?}", wx.read(&cpu).unwrap().as_slice::<f32>());
}
```
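For reference, the expected `[3.0, 1.0]` is what a row-by-row matrix-vector product gives when `w` is read row-major, as the printout above implies (element (i, j) at flat index `i * cols + j`). A minimal plain-Rust sketch with no coaster involved; `mat_vec` is a hypothetical helper used only to check the arithmetic:

```rust
/// Matrix-vector product y = W x for a row-major `rows x cols` matrix stored
/// in a flat slice (element (i, j) lives at index i * cols + j).
fn mat_vec(w: &[f32], rows: usize, cols: usize, x: &[f32]) -> Vec<f32> {
    assert_eq!(w.len(), rows * cols, "w must hold rows * cols elements");
    assert_eq!(x.len(), cols, "x must have one entry per column of w");
    (0..rows)
        .map(|i| {
            // Row i of W, dotted with x.
            w[i * cols..(i + 1) * cols]
                .iter()
                .zip(x)
                .map(|(a, b)| a * b)
                .sum::<f32>()
        })
        .collect()
}

fn main() {
    // Same flat, row-major data as `w_values` and `x_values` in the issue.
    let w = [1.0f32, 1.0, 3.0, -1.0];
    let x = [1.0f32, 2.0];
    let wx = mat_vec(&w, 2, 2, &x);
    println!("w*x = {:?}", wx); // prints "w*x = [3.0, 1.0]"
}
```

In BLAS terms this operation is `gemv`, or `gemm` with a one-column right-hand side, rather than `dot`. coaster-blas appears to expose a `Gemm` plugin trait next to `Dot`, which may be the intended route on the backend, but its exact signature should be checked in the coaster-blas documentation for the version in use.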

What am I doing wrong?

daniil-berg · Aug 07 '22 12:08

I'll take a deeper look next week; unfortunately, I am in the middle of a move.

drahnr · Aug 09 '22 12:08