ParallelDroplessMLP initialises self.mlp twice
What the title says. In layers/dmoe.py:
class ParallelDroplessMLP(moe.ParallelMLP):
    def __init__(self, args : Arguments):
        super(ParallelDroplessMLP, self).__init__(args)  # <-- first init!
        self.hidden_size = args.hidden_size
        self.ffn_hidden_size = mpu.features_per_rank(args)
        self.blocking = 128
        self.mlp = dmlp_registry.get(args)  # <-- second init!
As a subclass of moe.ParallelMLP, ParallelDroplessMLP first initialises self.mlp in super().__init__() (at layers/moe.py):
class ParallelMLP(torch.nn.Module):
    def __init__(self, args : Arguments):
        # ... omitted ...

        # Expert MLP.
        self.mlp = mlp.MLP(args)
This costs extra initialisation time and memory, because the weights created in the first init are immediately overwritten by the new weights created via self.mlp = dmlp_registry.get(args).
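For illustration, here is a minimal standalone sketch (plain PyTorch, not megablocks code) of the same pattern: the parent class allocates a module in its __init__, and the subclass immediately replaces it, so the first allocation only costs init time and transient memory.

import torch


class Parent(torch.nn.Module):
    """Stand-in for moe.ParallelMLP: allocates an expert MLP in __init__."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # These weights are allocated here...
        self.mlp = torch.nn.Linear(hidden_size, 4 * hidden_size)


class Child(Parent):
    """Stand-in for ParallelDroplessMLP: overwrites the parent's MLP."""

    def __init__(self, hidden_size: int):
        super().__init__(hidden_size)  # first allocation of self.mlp
        # ...and immediately discarded here.
        self.mlp = torch.nn.Linear(hidden_size, 4 * hidden_size)


child = Child(4096)  # the Linear built inside Parent.__init__ is thrown away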
Apologies in advance if this double-init process is actually crucially important to the mechanics of the library; I personally did not observe anything breaking after commenting out the first initialisation.
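One possible way to avoid the throwaway allocation, sketched under the assumption that the base class can delegate expert-MLP construction to an overridable helper; the _build_mlp name is my own and not existing megablocks API, and the imports are taken to be those already present in layers/moe.py and layers/dmoe.py:

# Sketch only -- in layers/moe.py:
class ParallelMLP(torch.nn.Module):
    def __init__(self, args : Arguments):
        # ... omitted ...

        # Expert MLP, created through an overridable hook rather than directly.
        self.mlp = self._build_mlp(args)

    def _build_mlp(self, args : Arguments):
        return mlp.MLP(args)


# Sketch only -- in layers/dmoe.py:
class ParallelDroplessMLP(moe.ParallelMLP):
    def __init__(self, args : Arguments):
        super(ParallelDroplessMLP, self).__init__(args)
        self.hidden_size = args.hidden_size
        self.ffn_hidden_size = mpu.features_per_rank(args)
        self.blocking = 128

    def _build_mlp(self, args : Arguments):
        # The subclass supplies its own MLP, so the base-class weights are never built.
        return dmlp_registry.get(args)

Because _build_mlp is dispatched dynamically from the base-class __init__, only one set of expert weights is ever constructed; dmlp_registry.get(args) depends only on args, so it does not matter that the subclass attributes are set afterwards.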