oneflow icon indicating copy to clipboard operation
oneflow copied to clipboard

nn.Module to_global 接口是否需要加一个参数来跳过 global tensor

Open wyg1997 opened this issue 2 years ago • 8 comments

一些搭好并配好 sbp 的 global 模型想做扩展时,可能会迁移过来一些 local module,组成一个 global module 和 local module 混合的情况,这时如果想对其中的 local module 做统一的 to_global 操作并跳过 local tensor 的话,需要用内部接口 _apply 并手动写一个处理函数,这种方式对于一般用户是比较难做到的。

需要讨论一下是否有必要对 to_global 加一个参数(skip_global_tensor=False 之类的,参数名可以再讨论),来表示是否在 to_global 时,跳过已经是 global tensor 的 Parameter。

import oneflow as flow

placement = flow.placement("cpu", [0, 1])
B = flow.sbp.broadcast
S0 = flow.sbp.split(0)

class Model(flow.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = flow.nn.Parameter(flow.ones(4, 4).to_global(placement, B))

    def forward(self, x):
        return x + self.param

m = Model()

m.new = flow.nn.Parameter(flow.zeros(4, 4))
print(m.new.is_global)

def local_to_global(x):
    if x.is_global:
        return x
    else:
        return x.to_global(placement, S0)

m._apply(local_to_global)
print(m.new.is_global)

wyg1997 avatar Aug 08 '22 04:08 wyg1997

这里是不是还有额外的一个假设:

如果 module 里有 local 和 global tensor, module.to_global(skip_global_tensor=true), module.to_global(only_apply_local=true), module.local_to_global(), local to global 参数里的 Placement,和原本 module 里的 global tensor 的 Placement 要一致?(只是 sbp 有可能不同)

chengtbf avatar Aug 08 '22 04:08 chengtbf

如果不跳过会有什么后果吗

hjchen2 avatar Aug 08 '22 04:08 hjchen2

如果不跳过会有什么后果吗

不跳过的话就把已经配好 sbp 的 module 也修改了,比如上面例子里,如果直接用 m.to_global(placement, S0) 的话,那 self.param 也变成了 S0。

wyg1997 avatar Aug 08 '22 05:08 wyg1997

这里是不是还有额外的一个假设:

如果 module 里有 local 和 global tensor, module.to_global(skip_global_tensor=true), module.to_global(only_apply_local=true), module.local_to_global(), local to global 参数里的 Placement,和原本 module 里的 global tensor 的 Placement 要一致?(只是 sbp 有可能不同)

这里应该不用处理,因为原本 module 里的 placement 可能就是不一样的。

wyg1997 avatar Aug 08 '22 05:08 wyg1997

这个需求肯定是存在的。

一一指定每个 tensor to global 的结果是最灵活的,但是也是最麻烦的。 之前的 module to global 假设了所有 tensor(weight) 的 placement 、sbp 都一致。但实际上很有可能不一致(混合并行情况下), 如果每次都让用户来写,也不太方便。 提供一些公共的方法可以组合解决用户的需求是合适的。但就看要到哪个程度。

chengtbf avatar Aug 08 '22 05:08 chengtbf

是不是提供modify_global或者transform_global的接口做个区分比较好。已经是global的再to global就报错?

jackalcooper avatar Aug 08 '22 05:08 jackalcooper

这个需要结合一个具体的例子来看看,看这种两阶段的 to_global 配置是否普遍、是否可以作为一种通用的功能。

如果这种使用模式非常不规整,可能还是需要让用户在 tensor 级别去做。

strint avatar Aug 08 '22 06:08 strint

不建议在 to_global 里面加这个参数,这种接口最好保持纯洁性和功能单一。用别的接口可以考虑,但我觉得也没必要,上面 m.new = flow.nn.Parameter(flow.zeros(4, 4)) 可以直接在 Parameter 后面接 to_global 也可以另起一行对 m.new 做 to_global,感觉上没必要提供这种功能。

leaves-zwx avatar Aug 08 '22 15:08 leaves-zwx