weight init in embedding_config should depend on embedding_dim
What
Currently, the default min and max weight init for EmbeddingConfig are based on "sqrt(1 / self.num_embeddings)" instead of "sqrt(1 / self.embedding_dim)":
https://github.com/zainhuda/torchrec/blob/4ea42e2f44f4485fd35865323236a0b277873216/torchrec/modules/embedding_configs.py#L175
I believe it should be the latter, from an embedding vector normalization perspective: the init scale should track the width of each embedding vector rather than the number of rows in the table.
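For a concrete sense of the difference, here is a minimal sketch (the table size and dim are hypothetical, not taken from the linked code):

```python
import math
import torch

# Hypothetical sizes, just to compare the two bounds.
num_embeddings = 1_000_000
embedding_dim = 128

bound_rows = math.sqrt(1 / num_embeddings)  # current default: ~0.001, shrinks as the table grows
bound_dim = math.sqrt(1 / embedding_dim)    # proposed: ~0.088, tracks the vector width

# Uniform draw between the min/max bounds discussed above, on a small demo table.
weights = torch.empty(1000, embedding_dim)
torch.nn.init.uniform_(weights, -bound_dim, bound_dim)
print(bound_rows, bound_dim, weights.std().item())
```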
Hmm, fair point. Put out a PR?
Sounds good. Any unit test scripts that I should be changing at the same time? @PaulZhang12
PR here: https://github.com/pytorch/torchrec/pull/1902
for context, I think DLRM inits it with 1 / num_embeddings
https://github.com/facebookresearch/dlrm/blob/main/dlrm_s_pytorch.py#L281
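For reference, a rough sketch of that kind of init (sizes are hypothetical; as I recall the linked line draws uniformly with a bound derived from num_embeddings, so check the source for the exact expression):

```python
import numpy as np

# Hypothetical sizes; the linked DLRM line does this kind of draw on the real table.
n, m = 100_000, 64  # rows, embedding dim

# Note the bound comes from the row count n, not the embedding dim m.
W = np.random.uniform(
    low=-np.sqrt(1 / n), high=np.sqrt(1 / n), size=(n, m)
).astype(np.float32)
```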
Thanks for the cross ref. Does the switch to embedding_dim make sense to you in general? @henrylhtsang
@di-wnd I think @PaulZhang12 brought it up, and it seems the result of that discussion is that it's better to change the user input than to change the default.
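For what it's worth, "changing the user input" would look roughly like this sketch; only the weight_init_min/weight_init_max overrides are the point here, and the other field values are hypothetical:

```python
from torchrec.modules.embedding_configs import EmbeddingConfig

dim = 128
bound = (1 / dim) ** 0.5

# Illustrative config: set the init bounds explicitly so they scale with
# embedding_dim instead of relying on the num_embeddings-based default.
config = EmbeddingConfig(
    name="large_table",           # hypothetical table name
    embedding_dim=dim,
    num_embeddings=1_000_000,
    feature_names=["feature_0"],  # hypothetical feature
    weight_init_min=-bound,
    weight_init_max=bound,
)
```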
@henrylhtsang Sorry, but I actually disagree. The default behavior should be at least reasonable, so that any training program relying on this default implementation won't suffer from a "bad init" issue that could degrade model performance. In practice, most users won't bother manually setting "weight_init_max" and "weight_init_min" for this dataclass, and when num_embeddings becomes large (say beyond 1M), all the corresponding vector values will be initialized with a "tiny" number, which might have an impact on training efficiency.
@di-wnd I actually don't know whether it would affect training efficiency. Maybe worth checking whether it would affect accuracy and metrics. The DLRM model's choice leads me to believe this needs more evidence.
EDIT: fixing typo