
GPT with UNet architecture gets the loss down to ~1.0 with no significant computation costs.

Open englertbruno opened this issue 2 years ago • 14 comments

TLDR: This is a bit of a tangent from what the original repo is for (education), but I think it's an interesting finding. Basically, by using the following architecture adapted from the UNet paper, we can use significantly deeper models while the computation cost is only slightly higher.

self.transformer = nn.ModuleDict(dict(
    wte=nn.Embedding(config.vocab_size, config.n_embd),
    wpe=nn.Embedding(config.block_size, config.n_embd),
    drop=nn.Dropout(config.dropout),
    # each compressing stage halves the sequence length: 1024 -> 32 after 5 stages
    compressing=nn.ModuleList([
        nn.ModuleList([Block(config) for _ in range(5)] + [BlockCompressing(config)]),
        nn.ModuleList([Block(config) for _ in range(2)] + [BlockCompressing(config)]),
        nn.ModuleList([Block(config) for _ in range(2)] + [BlockCompressing(config)]),
        nn.ModuleList([Block(config) for _ in range(2)] + [BlockCompressing(config)]),
        nn.ModuleList([Block(config) for _ in range(2)] + [BlockCompressing(config)])
    ]),
    # the deep middle stack runs on the short, compressed sequence
    middle=nn.ModuleList([Block(config) for _ in range(100)]),
    # each expanding stage doubles the sequence length back: 32 -> 1024
    expanding=nn.ModuleList([
        nn.ModuleList([BlockExpanding(config)] + [Block(config) for _ in range(2)]),
        nn.ModuleList([BlockExpanding(config)] + [Block(config) for _ in range(2)]),
        nn.ModuleList([BlockExpanding(config)] + [Block(config) for _ in range(2)]),
        nn.ModuleList([BlockExpanding(config)] + [Block(config) for _ in range(2)]),
        nn.ModuleList([BlockExpanding(config)] + [Block(config) for _ in range(2)]),
    ]),
    ln_f=nn.LayerNorm(config.n_embd),
))
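
Conceptually, each BlockCompressing halves the sequence length and each BlockExpanding doubles it back. Roughly, the idea is something like the following simplified sketch (the real blocks live in the linked model_gpt_unet.py and may differ in detail; merging/splitting neighbouring positions with a linear layer is only an illustration):

import torch.nn as nn

class BlockCompressing(nn.Module):
    # Sketch: halve the sequence length by merging adjacent token pairs.
    def __init__(self, config):
        super().__init__()
        self.merge = nn.Linear(2 * config.n_embd, config.n_embd)

    def forward(self, x):
        B, T, C = x.size()              # T is assumed to be even here
        x = x.view(B, T // 2, 2 * C)    # pair up neighbouring positions
        return self.merge(x)            # (B, T/2, C)

class BlockExpanding(nn.Module):
    # Sketch: double the sequence length by splitting each vector in two.
    def __init__(self, config):
        super().__init__()
        self.split = nn.Linear(config.n_embd, 2 * config.n_embd)

    def forward(self, x):
        B, T, C = x.size()
        return self.split(x).view(B, 2 * T, C)   # (B, 2T, C)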

The idea: representing the input with the full block of word vectors (in our case 1024) seems pretty wasteful. Words like "the", "a", "she", etc. don't carry much information, so the input could be compressed significantly. The same applies to sentences: most sentences don't contain that much information and could be stored in fewer vectors.

The other observation is that most of the words are kept (because the drop-out rate between input and output is low) and only need to be shifted. Shifting can be done with a single block, and we can assume that the missing words can be expressed with far fewer than 1024 vectors.

Under all these assumptions, most of the information is simply copied from the input to the output. This brings us to the UNet architecture: UNet was designed for segmentation, and its architecture was built on similar assumptions.

Adopting a UNet-style architecture has one major advantage: a shorter sequence length in the middle of the network. This brings the attention matrix down from 1024 × 1024 = 1,048,576 entries to 32 × 32 = 1,024. This is a significant reduction in computation, which means the middle layers can be much deeper.
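
A quick back-of-the-envelope check of the attention-matrix sizes:

# attention scores form a T x T matrix per head, per layer
full_T, compressed_T = 1024, 32
print(full_T * full_T)              # 1048576 entries at the original length
print(compressed_T * compressed_T)  # 1024 entries per layer after 5 compressions
# ~1000x fewer attention entries, so the 100 middle blocks on the compressed
# sequence add roughly the compute of only a few full-length blocks.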

I don't know how well this architecture would work for other tasks like translation, vision, or chatGPT, but for this particular case it seems to work well. The idea is also not new; after a few hours of searching, I found this paper from 2019: https://arxiv.org/pdf/1910.10488v1.pdf.

You can find the full model here: https://github.com/englertbruno/nanoGPT/blob/master/model_gpt_unet.py

englertbruno avatar Jan 22 '23 10:01 englertbruno

I find this very interesting! Thank you for sharing. Would be nice to see how well it performs on vision tasks.

sandorkonya avatar Jan 24 '23 11:01 sandorkonya

I undersold it a bit: at the time of writing the issue, the training was only around 75k iterations. After training it further, it achieves ~0.75 loss (at 600k iterations).

englertbruno avatar Jan 26 '23 11:01 englertbruno

This is a pretty significant difference! Would love to see it in action on HF ;)

sandorkonya avatar Jan 26 '23 19:01 sandorkonya

It looks like encoder => gpt => decoder. Maybe the parts could be trained separately? nanoBERT + nanoGPT

zhzLuke96 avatar Feb 03 '23 09:02 zhzLuke96

Is that training or validation loss? And if you generate ±10 texts, are they grammatically correct or garbled? I could imagine compression and decompression gives you an almost impressionistic result where the detail (word-level structure) is off but the loss is good because roughly the right tokens are in the answers. (Running it myself to find out, but that's going to take a while on my GPU...)

MichelNivard avatar Feb 06 '23 10:02 MichelNivard

Has anyone tested the output? Is it gibberish or meaningful?

jyviko avatar Feb 07 '23 20:02 jyviko

I tried this UNet architecture, and it doesn't seem to work well for text generation.

The trained model outputs a lot of noise. According to my tests, this noise (garbled characters) is strongly related to block_size/window_size. That is to say, the UNet model tends to require input of the same size as the training data; otherwise it fills in the missing part with garbled characters.

zhzLuke96 avatar Feb 07 '23 21:02 zhzLuke96

that is to say, the UNet model tends to require input of the same size as the training data; otherwise it fills in the missing part with garbled characters

Could something similar to NovelAI's aspect ratio bucketing work then?

decahedron1 avatar Feb 08 '23 00:02 decahedron1

that is to say, the UNet model tends to require input of the same size as the training data; otherwise it fills in the missing part with garbled characters

Could something similar to NovelAI's aspect ratio bucketing work then?

This method can indeed improve the performance of the model; I did a simple test. Although I didn't copy the bucketing training scheme exactly, I used a random window_size for training, which does reduce the gibberish output. But I still didn't get good results, only word-level learning (on tiktoken-encoded training data); it can output correct words but does not generate correct sentences.

However, I only trained on ~500 MB of web text that I collected myself, for 10k steps. Maybe more data and more training steps would help?
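
Roughly, the random window_size sampling looked like this (a simplified sketch adapted from nanoGPT's get_batch, not the exact code; the multiple-of-32 constraint follows from the 5 compression stages):

import numpy as np
import torch

def get_batch_random_window(data, batch_size, max_block_size=1024, multiple=32):
    # pick a random window length that stays a multiple of the compression factor
    block_size = multiple * np.random.randint(1, max_block_size // multiple + 1)
    ix = torch.randint(len(data) - block_size - 1, (batch_size,))
    x = torch.stack([torch.from_numpy(data[i:i+block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i+1:i+1+block_size].astype(np.int64)) for i in ix])
    return x, y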


If someone wants to continue improving this, I suggest trying some other position embedding methods.

zhzLuke96 avatar Feb 10 '23 15:02 zhzLuke96

Sorry for the noise, but I have one last question.

But I still didn't get good results, only word-level learning (on tiktoken-encoded training data); it can output correct words but does not generate correct sentences.

Did it output gibberish, or was the output passable (just not desired)?

decahedron1 avatar Feb 10 '23 16:02 decahedron1

Sorry for the noise, but I have one last question.

But I still didn't get good results, only word-level learning (on tiktoken-encoded training data); it can output correct words but does not generate correct sentences.

Did it output gibberish, or was the output passable (just not desired)?

It still outputs garbled characters, but the individual words are spelled correctly.


I can show you some examples (but I am training on my own Chinese corpus, so I can only show Chinese results).

UNet version output:

Is ChatGPT dangerous?
-��364EG��M����J������硫���cn Y����上�/)�.ob���LA��()、�
�110-://�����●�方��

UNet with random window_size version output:

苹果手机最新发布会,
该�手备�播情报单变氆在�拶尺对狍而发型節正开次�反尿尽性�屏送"全网在韾受是富可件上的采��,从追报��电�消费�已彤亚�标贃销近,将容得要背�的因逊入去户有实立用或。此成�势张就�质,,

original nanoGPT version output:

苹果手机最新发布会,
持续超过75年。 在20世纪70年代起,使用小型活塞偏离热的系统,以致系统节的变动,使用小型苹果装备更加容易自动切入系统。  

zhzLuke96 avatar Feb 10 '23 16:02 zhzLuke96

Are you padding the input? If you are not padding the input to be a multiple of 32 (5 compression blocks, so: 2^5 = 32), it will have issues.
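
For example, something along these lines (a minimal sketch only; the pad token id and the padding side are placeholder choices, not fixed by the repo):

import torch.nn.functional as F

def pad_to_multiple(idx, multiple=32, pad_token=0):
    # idx: (B, T) token ids; pad on the left so T becomes a multiple of 32 (2^5)
    pad = (-idx.size(1)) % multiple
    return F.pad(idx, (pad, 0), value=pad_token) if pad else idx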

englertbruno avatar Feb 10 '23 16:02 englertbruno

Are you padding the input? If you are not padding the input to be a multiple of 32 (5 compression blocks, so: 2^5 = 32), it will have issues.

I don't pad the input, but I do pad the tensor in the forward method of the compress and expand blocks to align the shapes, filling with zeros.

zhzLuke96 avatar Feb 10 '23 16:02 zhzLuke96

Another relevant paper about hierarchical processing in Transformer decoder models would be this one.

PiotrNawrot avatar Mar 17 '23 20:03 PiotrNawrot

In the past few days I finally had some time to look into this architecture again and realized that the reason the loss is so low is that causality is broken during the compression/expansion phase. In the self-attention part there is a triangular mask prohibiting the network from looking forward in time, but the UNet architecture breaks this. This is also the reason why it is so bad at generating sentences. Sorry for wasting your time. 😄
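
A tiny check makes the leak visible (illustrative only, not code from the repo): merge neighbouring positions the way a compression step does, and the representation at an earlier position now depends on a later token, which the causal mask can no longer prevent:

import torch

B, T, C = 1, 8, 4
x = torch.randn(B, T, C, requires_grad=True)

# pair up neighbours as a compression step would: position 0 gets mixed with position 1
merged = x.view(B, T // 2, 2 * C)
merged[0, 0].sum().backward()           # gradient of "compressed token 0"
print(x.grad[0, 1].abs().sum() > 0)     # tensor(True): the future token leaks backwards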

englertbruno avatar Apr 30 '23 20:04 englertbruno