caffe2_cpp_tutorial

StopGradient operator doesn't stop AddGradientOps from creating gradient operators

Open · breadbread1984 opened this issue 7 years ago • 7 comments

The ModelUtil::AddGradientOps function is supposed to skip creating gradient operators for the operators that come before a StopGradient operator, but it doesn't.

breadbread1984, Dec 02 '17
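
To make the expected behavior concrete, here is a minimal sketch (not the tutorial's code) of how to decide which forward operators should get gradient operators. It assumes a hypothetical, simplified Op struct in place of caffe2::OperatorDef and a hypothetical function OpsNeedingGradient. It walks the net backwards from the loss blob, tracking which blob names carry a gradient, and never lets a gradient pass through a StopGradient operator, including the in-place form StopGradient(h -> h).

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for caffe2::OperatorDef.
struct Op {
  std::string type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// For each forward op, report whether a gradient op should be generated.
// Walks the ops backwards from `loss`, tracking which blob names currently
// carry a gradient. StopGradient never passes a gradient to its inputs.
std::vector<bool> OpsNeedingGradient(const std::vector<Op>& ops,
                                     const std::string& loss) {
  std::set<std::string> has_grad = {loss};
  std::vector<bool> needs(ops.size(), false);
  for (std::size_t i = ops.size(); i-- > 0;) {
    const Op& op = ops[i];
    if (op.type == "StopGradient") {
      // Erasing the outputs also covers in-place use, e.g. StopGradient(h -> h).
      for (const auto& out : op.outputs) has_grad.erase(out);
      continue;
    }
    bool reached = false;
    for (const auto& out : op.outputs)
      if (has_grad.count(out) > 0) reached = true;
    if (!reached) continue;  // no gradient flows back into this op
    needs[i] = true;
    // Earlier writers of these names produced different values, so the
    // outputs stop carrying a gradient; the inputs start carrying one.
    for (const auto& out : op.outputs) has_grad.erase(out);
    for (const auto& in : op.inputs) has_grad.insert(in);
  }
  return needs;
}
```

For a net FC -> StopGradient -> FC -> loss, only the second FC and the loss operator are marked; the first FC and the StopGradient operator get no gradient operators.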

Correct, I didn't implement that yet. Adding it to the roadmap. Thanks!

leovandriel, Dec 04 '17

I think the following code could do the trick.

    #include <algorithm>
    #include <set>
    #include <string>

    // inside a function where net is a caffe2::NetDef
    auto ops = net.op();  // work on a copy of the operator list
    std::set<std::string> open, newopen;
    // delete all StopGradient ops, remembering the blobs they guard
    ops.erase(std::remove_if(ops.begin(), ops.end(),
        [&](const OperatorDef& op) {
            if (op.type() == "StopGradient") {
                for (auto& input : op.input()) open.insert(input);
                return true;
            }
            return false;
        }),
        ops.end());
    // delete all precursors of StopGradient ops, walking back towards the external inputs
    while (!open.empty()) {
        ops.erase(std::remove_if(ops.begin(), ops.end(),
            [&](const OperatorDef& op) {
                for (auto& output : op.output())
                    if (open.find(output) != open.end()) {
                        for (auto& input : op.input())
                            if (std::find(net.external_input().begin(), net.external_input().end(), input) == net.external_input().end())
                                newopen.insert(input);
                        return true;
                    }
                return false;
            }),
            ops.end());
        open = newopen;
        newopen.clear();
    }

breadbread1984, Dec 07 '17
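
The fragment above depends on Caffe2's protobuf types. The sketch below restates the same pruning idea on a hypothetical plain Op struct, with a hypothetical function PruneBeforeStopGradient, so it can be exercised without Caffe2: drop every StopGradient operator, then repeatedly drop the producers of the guarded blobs until only external inputs remain upstream.

```cpp
#include <algorithm>
#include <cassert>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-in for caffe2::OperatorDef.
struct Op {
  std::string type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Returns the ops that remain after removing StopGradient ops and every
// op that (transitively) feeds them, stopping at external inputs.
std::vector<Op> PruneBeforeStopGradient(
    std::vector<Op> ops, const std::set<std::string>& external_inputs) {
  std::set<std::string> open;
  // Drop every StopGradient op, remembering the blobs it guarded.
  ops.erase(std::remove_if(ops.begin(), ops.end(),
                           [&](const Op& op) {
                             if (op.type != "StopGradient") return false;
                             open.insert(op.inputs.begin(), op.inputs.end());
                             return true;
                           }),
            ops.end());
  // Repeatedly drop producers of "open" blobs, moving towards the inputs.
  while (!open.empty()) {
    std::set<std::string> next;
    ops.erase(std::remove_if(ops.begin(), ops.end(),
                             [&](const Op& op) {
                               for (const auto& out : op.outputs) {
                                 if (open.count(out) > 0) {
                                   for (const auto& in : op.inputs)
                                     if (external_inputs.count(in) == 0)
                                       next.insert(in);
                                   return true;
                                 }
                               }
                               return false;
                             }),
              ops.end());
    open = std::move(next);
  }
  return ops;
}
```

std::remove_if applies the predicate exactly once per element, so collecting blob names inside the predicate is safe here; the relative order of the surviving operators is preserved.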

Hi @breadbread1984, I added support for StopGradient in 6d9b16be6d5. Let me know if that works.

leovandriel, Dec 12 '17

ok

breadbread1984, Dec 13 '17

I tested it, and it works. Thanks!

breadbread1984, Dec 13 '17

AddOptimizerOps adds weight-update operators for all parameters, even for parameters upstream of a StopGradient operator, which receive no gradient.

breadbread1984, Dec 30 '17
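
One possible approach is to add update operators only for parameters whose gradient blob is actually produced by the gradient operators. A minimal sketch, assuming gradient blobs follow the "<param>_grad" naming convention and using a hypothetical Op struct and a hypothetical function ParamsWithGradient (neither is the tutorial's API):

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for caffe2::OperatorDef.
struct Op {
  std::string type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Keeps only the parameters whose gradient blob is written by some gradient
// op. Assumes gradient blobs are named "<param>_grad".
std::vector<std::string> ParamsWithGradient(
    const std::vector<std::string>& params, const std::vector<Op>& grad_ops) {
  std::set<std::string> produced;
  for (const auto& op : grad_ops)
    produced.insert(op.outputs.begin(), op.outputs.end());
  std::vector<std::string> result;
  for (const auto& param : params)
    if (produced.count(param + "_grad") > 0) result.push_back(param);
  return result;
}
```

If gradient operators are only generated downstream of StopGradient, the parameters upstream of it (for example w1 and b1 in an FC -> StopGradient -> FC net) have no "_grad" blob and would be skipped by the optimizer.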

Thanks, I'll take a look.

leovandriel, Jan 03 '18