
Is there a proper way to copy grad?

Open HyeokSuLee opened this issue on Jul 04 '22 • 0 comments

I'm trying to implement A3C. There is one global VarStore plus a local VarStore for each agent. My implementation looks like this:

  1. agent takes an action

  2. get rewards

  3. compute the loss and gradients on the agent's local VarStore

  4. copy the gradients to the global VarStore
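
For context, the two-VarStore setup for steps 1–3 looks roughly like this (the layer names, sizes, and dummy loss below are only for illustration, not my real model):

use tch::{nn, nn::Module, Device, Kind, Tensor};

// The same architecture is built twice: once on the global store, once per agent.
fn make_net(vs: &nn::Path) -> impl Module {
    nn::seq()
        .add(nn::linear(vs / "l1", 4, 32, Default::default()))
        .add_fn(|t| t.relu())
        .add(nn::linear(vs / "l2", 32, 2, Default::default()))
}

fn main() -> Result<(), tch::TchError> {
    let global_vs = nn::VarStore::new(Device::Cpu);
    let _global_net = make_net(&global_vs.root());

    let mut agent_vs = nn::VarStore::new(Device::Cpu);
    let agent_net = make_net(&agent_vs.root());

    // Pull the current global weights into the agent before a rollout.
    agent_vs.copy(&global_vs)?;

    // Steps 1-3: act, collect rewards, compute the loss on the *local* net.
    // A dummy observation and loss stand in for the real rollout here.
    let obs = Tensor::zeros(&[1, 4], (Kind::Float, Device::Cpu));
    let loss = agent_net.forward(&obs).sum(Kind::Float);
    loss.backward(); // the agent's tensors now hold grad values

    // Step 4, pushing those grads into global_vs, is where it gets hard.
    Ok(())
}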

Step 4 is really hard: the gradient is a tensor that is stored internally alongside each weight and bias tensor, so I thought I could just copy it across:

let mut global_vars = global.vs.variables();
let weight_path = "l1.weight";
let agent_vars = agent_vs.variables();
let agent_tensor = agent_vars.get(weight_path).unwrap();

let g_tensor = global_vars
    .get_mut(weight_path)
    .expect("error when getting global tensor with agent's path");
g_tensor.grad().copy_(&agent_tensor.grad());
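
If I check the grad on the global side just before that last copy_ call, the issue shows up directly (the defined() check is only for illustration):

let g_grad = g_tensor.grad();
// For a global parameter that has never been through backward(), this prints
// false: the grad is an undefined tensor, so there is nothing to copy into.
println!("global grad defined: {}", g_grad.defined());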

The problem is that g_tensor.grad() effectively returns nothing (it actually returns an undefined tensor), and I cannot find any way to manually initialize the grad value. After struggling for 8 hours with manual initialization, tensor copies, VarStore copies, shallow_clone, and so on, I eventually found the workaround below. But it is really messy. Is there a better solution for this?

            // Warm-up hack: run a meaningless computation through each global
            // variable and call backward() so that libtorch allocates a grad
            // tensor for it. After this, g_tensor.grad().copy_(&tensor.grad())
            // works because the grad is now defined.
            for v in global
                .vs
                .variables_
                .lock()
                .unwrap()
                .trainable_variables
                .iter_mut()
            {
                let zero = v
                    .tensor
                    .new_zeros(v.tensor.size().as_slice(), (Kind::Float, Device::Cpu));
                println!("zero: {:?}", &zero);
                let a = (v.tensor.shallow_clone() + zero)
                    .sum(Kind::Float)
                    .set_requires_grad(true);
                println!("a: {:?}", &a);
                a.backward();
                // Reset the freshly created grad to zero before copying into it.
                v.tensor.zero_grad();
                println!("grad: {:?}", v.tensor.grad());
            }
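
After that warm-up pass, every global parameter has a defined (zeroed) grad, so the actual copy from the comment above looks something like this (agent_grads is just a placeholder name for however the agent's gradients are keyed and collected):

// agent_grads: HashMap<String, Tensor> of the agent's gradients, keyed by
// variable path (placeholder for however they are collected in practice).
tch::no_grad(|| {
    let global_vars = global.vs.variables();
    for (name, g_tensor) in global_vars.iter() {
        if let Some(local_grad) = agent_grads.get(name) {
            // The grad tensor is now defined, so copy_ has storage to write into.
            g_tensor.grad().copy_(local_grad);
        }
    }
});

With the gradients in place on the global variables, an optimizer built on the global VarStore should then be able to apply them with step().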
