
`autodiff`: add support for nalgebra matrices/vectors.

Open · avhz opened this issue 2 years ago • 10 comments

avhz, Jul 03 '23 20:07

Oops, I just realised there's already an issue for this and you've assigned it to yourself.

This is what's blocking https://github.com/avhz/RustQuant/issues/63 for me, and I'm trying to fix it as well.

nalgebra requires its matrix elements to be `'static`, so it's not easy to create a matrix of our `Variable` struct, since it holds a reference to `Graph` (and therefore carries a non-`'static` lifetime). I'm currently refactoring `Variable` to use `Rc<Graph>` instead, but that leads to other troublesome problems.

You can see my progress in https://github.com/avhz/RustQuant/compare/main...WenqingZong:RustQuant:wenqing/neural-networks. The code currently doesn't compile because of an error I'm not sure how to fix; I was going to ask for your help, but then found you were working on it as well.

The problem: in `Variable.rs`, after changing `graph: &Graph` to `graph: Rc<Graph>`, the `var()` method in `graph.rs` fails to compile, because creating a `Variable` would now take ownership of the `Graph` instance.
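A minimal sketch of one way around the `var()` ownership error, assuming a simplified graph that only stores node values (`Graph`, its `values` field, and the `Graph::var(&Rc<Graph>, f64)` signature are hypothetical, not RustQuant's actual API). Instead of taking `&self`, `var` takes a reference to the `Rc` and clones the handle, so the graph itself is never moved and `Variable` loses its lifetime parameter.

```rust
use std::cell::RefCell;
use std::rc::Rc;

/// Hypothetical minimal graph: it only records node values.
#[derive(Debug, Default)]
struct Graph {
    values: RefCell<Vec<f64>>,
}

/// Owns a shared handle to the graph instead of borrowing it, so it has no
/// lifetime parameter and satisfies a `'static` bound.
#[derive(Clone, Debug)]
struct Variable {
    graph: Rc<Graph>,
    index: usize,
    value: f64,
}

impl Graph {
    /// Takes `&Rc<Graph>` rather than `&self`: `Rc::clone` only bumps the
    /// reference count, so the caller keeps its own handle to the graph.
    fn var(graph: &Rc<Graph>, value: f64) -> Variable {
        let mut values = graph.values.borrow_mut();
        values.push(value);
        Variable { graph: Rc::clone(graph), index: values.len() - 1, value }
    }
}
```

Usage would look like `let g = Rc::new(Graph::default()); let x = Graph::var(&g, 1.0);`. The catch is that `Variable` can no longer be `Copy`, because `Rc` is not `Copy`.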

WenqingZong, Jul 06 '23 23:07

If I recall correctly, using `Rc` means that the user needs to do things like:

```rust
// Original
let z = x + y;

// Using `Rc`
let z = x.clone() + y.clone();
```

I think that's not user-friendly. Is this the problem you had?
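A minimal sketch of the usual Rust workaround for non-`Copy` numeric types, assuming a hypothetical `Variable` that holds an `Rc` to a shared tape: implementing the operators on references lets callers write `&x + &y` instead of cloning. The tape here only records values, not derivatives, and none of this is RustQuant's API.

```rust
use std::cell::RefCell;
use std::ops::Add;
use std::rc::Rc;

/// Hypothetical non-`Copy` variable: it holds an `Rc` to a shared tape.
#[derive(Clone, Debug)]
struct Variable {
    tape: Rc<RefCell<Vec<f64>>>,
    value: f64,
}

// By-reference impl: `&x + &y` leaves `x` and `y` usable afterwards.
impl<'a, 'b> Add<&'b Variable> for &'a Variable {
    type Output = Variable;
    fn add(self, rhs: &'b Variable) -> Variable {
        let value = self.value + rhs.value;
        self.tape.borrow_mut().push(value); // record the result on the tape
        Variable { tape: Rc::clone(&self.tape), value }
    }
}

// By-value impl: `x + y` still compiles, but consumes both operands.
impl Add for Variable {
    type Output = Variable;
    fn add(self, rhs: Variable) -> Variable {
        &self + &rhs
    }
}
```

The by-value impl simply delegates to the by-reference one. Users still have to remember the `&`, which is arguably less friendly than `Copy` semantics, but it avoids explicit `.clone()` calls.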

avhz, Jul 07 '23 09:07

> If I recall correctly, using `Rc` means that the user needs to do things like:
>
> ```rust
> // Original
> let z = x + y;
>
> // Using `Rc`
> let z = x.clone() + y.clone();
> ```
>
> I think that's not user-friendly. Is this the problem you had?

Hmm, OK, let's forget about the `Rc` solution then. But the current `Variable` struct still can't be used to populate an nalgebra matrix, because it carries the lifetime `'v`, and that's really blocking the neural-net implementation.

WenqingZong, Jul 07 '23 12:07

I could try ndarray and see if it has the same problem?

avhz, Jul 08 '23 09:07

> I could try ndarray and see if it has the same problem?

Superb!

WenqingZong, Jul 08 '23 22:07

I briefly tried ndarray this morning, and it seems it can hold custom types just fine; however, I think we would need to implement matmul etc. ourselves to make it useful.
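A minimal sketch of such a hand-written matmul, assuming only `Clone + Add + Mul` on the element type (the same kind of bounds the component-wise product relies on). The `matmul` function and the row-major `Vec<Vec<T>>` representation are hypothetical and not part of ndarray or RustQuant. Seeding each dot product with its first term, rather than with a zero value, means no `Zero` implementation is required.

```rust
use std::ops::{Add, Mul};

/// Naive dense product of row-major matrices: (n x k) * (k x m) -> (n x m).
/// Each entry is seeded with its k = 0 term instead of `T::zero()`.
/// Panics if the shapes are empty or do not match.
fn matmul<T>(a: &[Vec<T>], b: &[Vec<T>]) -> Vec<Vec<T>>
where
    T: Clone + Add<Output = T> + Mul<Output = T>,
{
    let inner = b.len();
    assert!(inner > 0 && a.iter().all(|row| row.len() == inner), "shape mismatch");
    let cols = b[0].len();
    assert!(b.iter().all(|row| row.len() == cols), "ragged right-hand matrix");

    a.iter()
        .map(|row| {
            (0..cols)
                .map(|j| {
                    // sum over k of a[i][k] * b[k][j], starting from the k = 0 term
                    (1..inner).fold(row[0].clone() * b[0][j].clone(), |acc, k| {
                        acc + row[k].clone() * b[k][j].clone()
                    })
                })
                .collect()
        })
        .collect()
}
```

With the current `Copy` `Variable`, `clone()` is just a copy, so this should work for it as long as `Add` and `Mul` are implemented by value.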

I can maybe take a closer look tomorrow.

avhz, Jul 09 '23 11:07

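The trouble with ndarray is that, for matrix multiplication, the element type has to satisfy the bounds of `Dot` (ndarray's `LinalgScalar`, i.e. `'static + Copy + Zero + One` plus the arithmetic operators), and I can't implement `Dot` myself since it's an external trait.

Component-wise multiplication works fine, but I don't think matrix multiplication is possible.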

```rust
#[cfg(test)]
mod test_ndarray {
    use crate::autodiff::Graph;

    #[test]
    fn test_component_mul() {
        let graph = Graph::new();

        // Distinct names, so neither the graph nor the matrices get shadowed.
        let (a11, a12, a21, a22) = (graph.var(1.), graph.var(2.), graph.var(3.), graph.var(4.));
        let (b11, b12, b21, b22) = (graph.var(5.), graph.var(6.), graph.var(7.), graph.var(8.));

        // a = [[1, 2],
        //      [3, 4]]
        // b = [[5, 6],
        //      [7, 8]]
        let a = ndarray::array![[a11, a12], [a21, a22]];
        let b = ndarray::array![[b11, b12], [b21, b22]];

        // COMPONENT-WISE MULTIPLICATION
        // c = [[5 , 12],
        //      [21, 32]]
        // Multiply by reference so `a` and `b` are still available below.
        let c = &a * &b;                                        // <--- This works fine.
        let c_values = c.map(|x| x.value);
        let c_expected = ndarray::array![[5., 12.], [21., 32.]];

        // MATRIX MULTIPLICATION
        // let dot = a.dot(&b);                                 // <--- This does not work.

        // Compare the f64 values: `c` itself holds `Variable`s, not f64s.
        assert_eq!(c_values, c_expected);

        println!("c: {:?}", c);
        println!("c_values: {:?}", c_values);
        println!("c_expected: {:?}", c_expected);
    }
}
```

avhz, Jul 09 '23 15:07

I think it would work if I could implement the `num::One` and `num::Zero` traits for the `Variable` type, but I don't see how this could be done.
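Part of the difficulty is that `zero()` and `one()` take no arguments, so there is no way to hand them a `Graph`. A minimal sketch of one workaround, assuming a hypothetical `Variable` whose graph reference is optional: constants are created without a graph and only join a tape when combined with a real variable. The local `Zero` trait mirrors the required methods of `num_traits::Zero` so the sketch has no dependencies, and the tape only records parent indices, not derivatives. The type still carries `'v`, so on its own this would not meet a `'static` bound such as nalgebra's `Scalar` or ndarray's `LinalgScalar`.

```rust
use std::cell::RefCell;
use std::ops::Add;

/// Hypothetical minimal tape: each node records the tape indices of its two
/// parents; `None` marks a leaf slot or a constant operand.
#[derive(Default)]
struct Graph {
    nodes: RefCell<Vec<[Option<usize>; 2]>>,
}

impl Graph {
    /// Pushes a leaf node and returns a variable bound to this graph.
    fn var(&self, value: f64) -> Variable<'_> {
        let mut nodes = self.nodes.borrow_mut();
        nodes.push([None, None]);
        Variable { graph: Some(self), index: Some(nodes.len() - 1), value }
    }
}

/// Still `Copy`, but the graph reference is optional.
#[derive(Clone, Copy)]
struct Variable<'v> {
    graph: Option<&'v Graph>, // `None` for a graph-less constant
    index: Option<usize>,     // tape position; `None` for constants
    value: f64,
}

/// Local trait with the same shape as `num_traits::Zero`'s required methods.
trait Zero: Sized + Add<Output = Self> {
    fn zero() -> Self;
    fn is_zero(&self) -> bool;
}

impl<'v> Zero for Variable<'v> {
    fn zero() -> Self {
        // No graph needed: the constant joins a tape only when it is
        // combined with a variable that already lives on one.
        Variable { graph: None, index: None, value: 0.0 }
    }
    fn is_zero(&self) -> bool {
        self.value == 0.0
    }
}

impl<'v> Add for Variable<'v> {
    type Output = Variable<'v>;
    fn add(self, rhs: Variable<'v>) -> Variable<'v> {
        let value = self.value + rhs.value;
        // Assumes both operands (if on a tape) share the same graph.
        match self.graph.or(rhs.graph) {
            Some(g) => {
                let mut nodes = g.nodes.borrow_mut();
                nodes.push([self.index, rhs.index]);
                Variable { graph: Some(g), index: Some(nodes.len() - 1), value }
            }
            // Both operands are constants, so the sum is a constant too.
            None => Variable { graph: None, index: None, value },
        }
    }
}
```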

avhz, Jul 09 '23 16:07

> I think it would work if I could implement the `num::One` and `num::Zero` traits for the `Variable` type, but I don't see how this could be done.

Would a global, static graph variable solve that?

WenqingZong, Jul 09 '23 23:07

I tried that earlier, but it required `unsafe`, and I still got a lifetime issue when trying to implement matrix multiplication.

You can see it in the `ndarray.rs` file in the latest commit.
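A global graph does not necessarily need `unsafe` if it is thread-local rather than a `static`. A minimal sketch, assuming a thread-local reverse-mode tape where `Variable` is just an index plus a value, so it is `Copy + 'static` and can be constructed with no graph argument. All names (`TAPE`, `push`, `Variable::new`, `grad`) are hypothetical, not RustQuant's API; the tape grows until something clears it, and variables from different threads must not be mixed.

```rust
use std::cell::RefCell;
use std::ops::{Add, Mul};

thread_local! {
    // One tape per thread. Each node stores two (parent index, partial) pairs.
    static TAPE: RefCell<Vec<[(usize, f64); 2]>> = RefCell::new(Vec::new());
}

/// No reference and no lifetime, so it is `Copy + 'static`.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Variable {
    index: usize,
    value: f64,
}

/// Appends a node to this thread's tape and returns a handle to it.
fn push(parents: [(usize, f64); 2], value: f64) -> Variable {
    TAPE.with(|tape| {
        let mut tape = tape.borrow_mut();
        tape.push(parents);
        Variable { index: tape.len() - 1, value }
    })
}

impl Variable {
    /// A leaf: its parent slots point at itself with weight 0,
    /// so the reverse sweep adds nothing through them.
    fn new(value: f64) -> Self {
        let index = TAPE.with(|tape| tape.borrow().len());
        push([(index, 0.0), (index, 0.0)], value)
    }

    /// Argument-free constructors, the shape `num::Zero`/`num::One` need.
    fn zero() -> Self {
        Variable::new(0.0)
    }
    fn one() -> Self {
        Variable::new(1.0)
    }

    /// Reverse sweep: entry `i` of the result is d(self)/d(node i).
    fn grad(&self) -> Vec<f64> {
        TAPE.with(|tape| {
            let tape = tape.borrow();
            let mut adjoints = vec![0.0; tape.len()];
            adjoints[self.index] = 1.0;
            for i in (0..=self.index).rev() {
                for &(parent, weight) in &tape[i] {
                    let contribution = weight * adjoints[i];
                    adjoints[parent] += contribution;
                }
            }
            adjoints
        })
    }
}

impl Add for Variable {
    type Output = Variable;
    fn add(self, rhs: Variable) -> Variable {
        // d(a + b)/da = 1, d(a + b)/db = 1
        push([(self.index, 1.0), (rhs.index, 1.0)], self.value + rhs.value)
    }
}

impl Mul for Variable {
    type Output = Variable;
    fn mul(self, rhs: Variable) -> Variable {
        // d(a * b)/da = b, d(a * b)/db = a
        push([(self.index, rhs.value), (rhs.index, self.value)], self.value * rhs.value)
    }
}
```

For `z = x * y + x` with `x = 3` and `y = 4`, `z.grad()` gives `dz/dx = y + 1 = 5` and `dz/dy = x = 3`. To satisfy ndarray's `LinalgScalar` this would additionally need `Sub`, `Div`, and real `num::Zero`/`num::One` impls, which the sketch leaves out to stay dependency-free.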

avhz, Jul 09 '23 23:07