[WIP][Dy2St] Use descriptor for convert a layer to avoid circular reference
PR Category
Execute Infrastructure
PR Types
Bug fixes
Description
[!TIP]
图中
a -> b,表示 b 持有 a(的引用)
原来 to_static(model) 的实现是 model.forward = decorated(model.forward),这导致了一个循环引用,即 model.forward 引用了 static_fn,static_fn.class_instance、static_fn.__wrapped__.__self__ 等位置引用了 model
这在大多数情况也没啥问题,但在 #68743 中的 case 就会因为延迟 gc 而 OOM 了
不过仔细一想,为什么 Python 的其他「方法」没有这个问题呢?
比如 model 通过 model.forward 持有 bound method,而 bound method 又通过 method.__self__ 持有 model
不过再仔细一想就很清楚了,model 实际并没有持有 model.forward,model.forward 是在这个 LOAD_ATTR/LOAD_METHOD 时候通过 model.__class__.forward 这个 unbound method(或者说 descriptor)通过 bind model(也就是 model.__class__.forward.__get__(model, model.__class__))得到的,model 实际上并没有持有这样一个方法,而是在每次重新 bind 的,这可以通过每次得到的 bound method 的 id 都不同来验证
那么,自然也就想到我们的 static function 也用同样的方法了,对于 to_static(model) 来说,换成 model.__class__.forward = decorated(model.__class__.forward),这样 model.forward 就是每次重新 bind 的结果了
不过要注意不能影响其它实例的 forward 方法,因此还需要加一个实例的白名单,只有 to_static 时候传入的实例才会用 static function bind,否则直接用原来的 dygraph function bind 得到动态图的 bound method
PCard-66972
你的PR提交成功,感谢你对开源项目的贡献! 请关注后续CI自动化测试结果,详情请参考Paddle-CI手册。 Your PR has been submitted. Thanks for your contribution! Please wait for the result of CI firstly. See Paddle CI Manual for details.