caffe2_cpp_tutorial
StopGradient operator doesn't stop AddGradientOps from creating gradient operators
ModelUtil::AddGradientOps is supposed to skip gradient operators for every operator that precedes a StopGradient operator, but it still creates them.
Correct, I didn't implement that yet. Adding it to the roadmap. Thanks!
I think the following code could do the trick.
    // Needs <algorithm>, <set>, and <string>.
    auto ops = net.op();  // copy of the operator list; net itself is not modified
    std::set<std::string> open, newopen;

    // Delete all StopGradient ops and remember their inputs.
    ops.erase(std::remove_if(ops.begin(), ops.end(),
        [&](const OperatorDef& op) {
          if (op.type() == "StopGradient") {
            for (const auto& input : op.input()) open.insert(input);
            return true;
          }
          return false;
        }),
      ops.end());

    // Delete all precursors of the StopGradient ops, one layer per iteration.
    while (!open.empty()) {
      ops.erase(std::remove_if(ops.begin(), ops.end(),
          [&](const OperatorDef& op) {
            for (const auto& output : op.output()) {
              if (open.find(output) != open.end()) {
                // Keep walking backward through this op's inputs,
                // except for the net's external inputs.
                for (const auto& input : op.input()) {
                  if (std::find(net.external_input().begin(),
                                net.external_input().end(),
                                input) == net.external_input().end())
                    newopen.insert(input);
                }
                return true;
              }
            }
            return false;
          }),
        ops.end());
      open = newopen;
      newopen.clear();
    }
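One caveat with deleting every precursor: an operator whose output also feeds a branch without StopGradient would lose its gradient as well. A minimal alternative sketch, assuming the operators are stored in execution order, walks backward from the loss blob and refuses to cross a StopGradient. `Op` and `OpsNeedingGradient` are hypothetical stand-ins made up for illustration; they are not part of Caffe2 or of this repository.

```cpp
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for caffe2::OperatorDef, reduced to what the sketch needs.
struct Op {
  std::string type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Returns the indices (in execution order) of the ops that need a gradient
// operator: those reachable backward from `loss` without crossing StopGradient.
// Assumes `ops` is in execution (topological) order.
std::vector<std::size_t> OpsNeedingGradient(const std::vector<Op>& ops,
                                            const std::string& loss) {
  std::set<std::string> wanted = {loss};  // blobs whose gradient is still required
  std::vector<bool> keep(ops.size(), false);

  for (std::size_t i = ops.size(); i-- > 0;) {
    const Op& op = ops[i];
    bool produces_wanted = false;
    for (const auto& out : op.outputs) {
      if (wanted.count(out)) produces_wanted = true;
    }
    if (!produces_wanted) continue;

    // This op defines these blobs, so earlier versions of the same names
    // (in-place ops) are no longer wanted on their account.
    for (const auto& out : op.outputs) wanted.erase(out);

    if (op.type == "StopGradient") continue;  // do not propagate past it

    keep[i] = true;
    wanted.insert(op.inputs.begin(), op.inputs.end());
  }

  std::vector<std::size_t> result;
  for (std::size_t i = 0; i < ops.size(); ++i) {
    if (keep[i]) result.push_back(i);
  }
  return result;
}
```

For a net like FC, StopGradient, FC, loss, this keeps only the second FC and the loss operator, so the first FC gets no gradient operator.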
Hi @breadbread1984, I added support for StopGradient in 6d9b16be6d5. Let me know if that works.
ok
I tested it and it works. Thanks!
However, AddOptimizerOps still adds weight update operators for all parameters, even for parameters whose gradients are blocked by StopGradient.
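One possible fix, sketched under assumptions: only emit an update for a parameter when a gradient blob was actually produced for it. The sketch assumes gradient blobs are named by appending `_grad` to the parameter name. `ParamsToUpdate` is a hypothetical helper, not an existing function in this repository.

```cpp
#include <set>
#include <string>
#include <vector>

// Returns the parameters that have a matching gradient blob among the blobs
// produced by the gradient operators. The "_grad" suffix is an assumption
// about the naming scheme and can be overridden.
std::vector<std::string> ParamsToUpdate(
    const std::vector<std::string>& params,
    const std::set<std::string>& produced_blobs,
    const std::string& grad_suffix = "_grad") {
  std::vector<std::string> result;
  for (const auto& param : params) {
    if (produced_blobs.count(param + grad_suffix)) result.push_back(param);
  }
  return result;
}
```

The optimizer step would then loop over the returned list instead of over every parameter.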
Thanks, I'll take a look.