DiffuSeq
Randomly Initialized embeddings?
Hi, thanks for sharing your implementation; it is really helpful and easy to follow.
I have one question about the embedding used in your model. There are three options discussed in the Diffusion-LM paper: 1. fixed randomly initialized embeddings; 2. fixed embeddings initialized from a PLM (like BERT); or 3. embeddings trained end-to-end (E2E), as in Diffusion-LM. I could not find which one you used in your paper.
From your code, it seems that you use randomly initialized embeddings, which would be different from Diffusion-LM? Specifically, I find that the input ids are embedded into 128-d random embeddings in the data loading process: https://github.com/Shark-NLP/DiffuSeq/blob/8bfafcbb26df218073b8117234afb9de9dfcbec9/basic_utils.py#L71
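Just to check my understanding, here is a minimal sketch of what I think that line is doing (the vocabulary size and the variable names are only my assumptions, not copied from your code):

```python
import torch
import torch.nn as nn

# My reading of the linked line in basic_utils.py: a freshly (randomly)
# initialized embedding table is created while the data is loaded ...
vocab_size = 30522                          # placeholder; whatever the tokenizer uses
model_emb = nn.Embedding(vocab_size, 128)   # randomly initialized, 128-d

# ... and the token ids are mapped through it when building the dataset,
# so the dataset already carries 128-d embedding vectors.
input_ids = torch.tensor([[101, 2023, 2003, 1037, 7099, 102]])
x_start = model_emb(input_ids)              # shape: (1, seq_len, 128)
```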
Please correct me if I am wrong.
Hi, we randomly initialize the embeddings and train them end-to-end, so it is the same setting as Diffusion-LM. We also tried the second setting (fixed embeddings initialized from a PLM), which is compared against joint E2E training in our paper (Table 3). The line you quote loads the 128-d random embeddings, but they are not used in later training; we keep this interface in case you need to initialize the embeddings in other settings. For DiffuSeq, the word embedding that is actually used is defined here:
https://github.com/Shark-NLP/DiffuSeq/blob/8bfafcbb26df218073b8117234afb9de9dfcbec9/diffuseq/transformer_model.py#L55
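Conceptually it looks like this (a simplified sketch, not the exact class in the repo; the class name and constructor arguments are paraphrased):

```python
import torch.nn as nn

class TransformerNetModelSketch(nn.Module):
    """Simplified sketch: the word embedding is a parameter of the model,
    so it receives gradients and is trained jointly (end-to-end)."""
    def __init__(self, vocab_size, hidden_dim=128):
        super().__init__()
        # This is the embedding DiffuSeq actually uses (cf. the linked line).
        self.word_embedding = nn.Embedding(vocab_size, hidden_dim)

    def get_embeds(self, input_ids):
        # Called during training to produce x_start from the token ids,
        # rather than reusing the random embeddings stored by the data loader.
        return self.word_embedding(input_ids)
```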
For future reference, it seems the randomly initialized embeddings are created and passed in here: https://github.com/Shark-NLP/DiffuSeq/blob/901f8604d619c1923d69e57bd11894523309aab8/diffuseq/text_datasets.py#L192
But they are ultimately not used: https://github.com/Shark-NLP/DiffuSeq/blob/901f8604d619c1923d69e57bd11894523309aab8/diffuseq/gaussian_diffusion.py#L596-L600
According to your code, the randomly initialized embeddings are passed in as x_start and then assigned to a local variable x_start_fix, but x_start_fix is not used in the later code. If I want to use fixed randomly initialized embeddings for my experiment, do I just need to set x_start_mean = x_start_fix, as sketched below? Please correct me if I am wrong.
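To be concrete, this is roughly the change I have in mind inside training_losses (a paraphrased sketch of the linked lines, not the exact code; please correct the details if they are off):

```python
# Paraphrased sketch of the linked region in gaussian_diffusion.py
def training_losses(self, model, x_start, t, model_kwargs=None):
    # x_start is the randomly initialized 128-d embedding produced by the
    # data loader; it is kept in x_start_fix but never read again afterwards.
    x_start_fix = x_start

    input_ids = model_kwargs['input_ids']
    # What is actually used: the model's own trainable word embedding.
    x_start_mean = model.get_embeds(input_ids)

    # Proposed change for the fixed-random-embedding setting:
    # x_start_mean = x_start_fix

    ...  # rest of the loss computation unchanged
```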