oneflow
oneflow copied to clipboard
nn.Module to_global 接口是否需要加一个参数来跳过 global tensor
一些搭好并配好 sbp 的 global 模型想做扩展时,可能会迁移过来一些 local module,组成一个 global module 和 local module 混合的情况,这时如果想对其中的 local module 做统一的 to_global 操作并跳过 local tensor 的话,需要用内部接口 _apply
并手动写一个处理函数,这种方式对于一般用户是比较难做到的。
需要讨论一下是否有必要对 to_global 加一个参数(skip_global_tensor=False
之类的,参数名可以再讨论),来表示是否在 to_global 时,跳过已经是 global tensor 的 Parameter。
import oneflow as flow
placement = flow.placement("cpu", [0, 1])
B = flow.sbp.broadcast
S0 = flow.sbp.split(0)
class Model(flow.nn.Module):
def __init__(self):
super().__init__()
self.param = flow.nn.Parameter(flow.ones(4, 4).to_global(placement, B))
def forward(self, x):
return x + self.param
m = Model()
m.new = flow.nn.Parameter(flow.zeros(4, 4))
print(m.new.is_global)
def local_to_global(x):
if x.is_global:
return x
else:
return x.to_global(placement, S0)
m._apply(local_to_global)
print(m.new.is_global)
这里是不是还有额外的一个假设:
如果 module 里有 local 和 global tensor, module.to_global(skip_global_tensor=true)
, module.to_global(only_apply_local=true)
, module.local_to_global()
, local to global 参数里的 Placement,和原本 module 里的 global tensor 的 Placement 要一致?(只是 sbp 有可能不同)
如果不跳过会有什么后果吗
如果不跳过会有什么后果吗
不跳过的话就把已经配好 sbp 的 module 也修改了,比如上面例子里,如果直接用 m.to_global(placement, S0)
的话,那 self.param 也变成了 S0。
这里是不是还有额外的一个假设:
如果 module 里有 local 和 global tensor,
module.to_global(skip_global_tensor=true)
,module.to_global(only_apply_local=true)
,module.local_to_global()
, local to global 参数里的 Placement,和原本 module 里的 global tensor 的 Placement 要一致?(只是 sbp 有可能不同)
这里应该不用处理,因为原本 module 里的 placement 可能就是不一样的。
这个需求肯定是存在的。
一一指定每个 tensor to global 的结果是最灵活的,但是也是最麻烦的。 之前的 module to global 假设了所有 tensor(weight) 的 placement 、sbp 都一致。但实际上很有可能不一致(混合并行情况下), 如果每次都让用户来写,也不太方便。 提供一些公共的方法可以组合解决用户的需求是合适的。但就看要到哪个程度。
是不是提供modify_global或者transform_global的接口做个区分比较好。已经是global的再to global就报错?
这个需要结合一个具体的例子来看看,看这种两阶段的 to_global 配置是否普遍、是否可以作为一种通用的功能。
如果这种使用模式非常不规整,可能还是需要让用户在 tensor 级别去做。
不建议在 to_global 里面加这个参数,这种接口最好保持纯洁性和功能单一。用别的接口可以考虑,但我觉得也没必要,上面 m.new = flow.nn.Parameter(flow.zeros(4, 4))
可以直接在 Parameter 后面接 to_global 也可以另起一行对 m.new 做 to_global,感觉上没必要提供这种功能。